From a95bc22e648d01158d3a4fd597059135e1302266 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 5 Aug 2014 11:17:28 -0700 Subject: [PATCH 01/34] timing for DecisionTree internals --- .../spark/mllib/tree/DecisionTree.scala | 80 ++++++++++++++++++- 1 file changed, 76 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1d03e6e3b36cf..1330d9c891dfa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import java.util.Calendar + import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint @@ -29,6 +31,40 @@ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom +class TimeTracker { + + var tmpTime: Long = Calendar.getInstance().getTimeInMillis + + def reset(): Unit = { + tmpTime = Calendar.getInstance().getTimeInMillis + } + + def elapsed(): Long = { + Calendar.getInstance().getTimeInMillis - tmpTime + } + + var initTime: Long = 0 // Data retag and cache + var findSplitsBinsTime: Long = 0 + var extractNodeInfoTime: Long = 0 + var extractInfoForLowerLevelsTime: Long = 0 + var findBestSplitsTime: Long = 0 + var findBinsForLevelTime: Long = 0 + var binAggregatesTime: Long = 0 + var chooseSplitsTime: Long = 0 + + override def toString: String = { + s"DecisionTree timing\n" + + s"initTime: $initTime\n" + + s"findSplitsBinsTime: $findSplitsBinsTime\n" + + s"extractNodeInfoTime: $extractNodeInfoTime\n" + + s"extractInfoForLowerLevelsTime: $extractInfoForLowerLevelsTime\n" + + s"findBestSplitsTime: $findBestSplitsTime\n" + + s"findBinsForLevelTime: $findBinsForLevelTime\n" + + s"binAggregatesTime: $binAggregatesTime\n" + + s"chooseSplitsTime: $chooseSplitsTime\n" + } +} + /** * :: Experimental :: * A class which implements a decision tree learning algorithm for classification and regression. @@ -47,16 +83,24 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { + val timer = new TimeTracker() + timer.reset() + // Cache input RDD for speedup during multiple passes. val retaggedInput = input.retag(classOf[LabeledPoint]).cache() logDebug("algo = " + strategy.algo) + timer.initTime += timer.elapsed() + timer.reset() + // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) val numBins = bins(0).length logDebug("numBins = " + numBins) + timer.findSplitsBinsTime += timer.elapsed() + // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree @@ -98,6 +142,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * still survived the filters of the parent nodes. 
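     * (Equivalently: a sample contributes to a node's split statistics only if it
     * would reach that node in the tree built so far.)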
*/ + var findBestSplitsTime: Long = 0 + var extractNodeInfoTime: Long = 0 + var extractInfoForLowerLevelsTime: Long = 0 + var level = 0 var break = false while (level <= maxDepth && !break) { @@ -106,16 +154,23 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("level = " + level) logDebug("#####################################") + // Find best split for all nodes at a level. + timer.reset() val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities, - strategy, level, filters, splits, bins, maxLevelForSingleGroup) + strategy, level, filters, splits, bins, timer, maxLevelForSingleGroup) + timer.findBestSplitsTime += timer.elapsed() for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + timer.reset() // Extract info for nodes at the current level. extractNodeInfo(nodeSplitStats, level, index, nodes) + timer.extractNodeInfoTime += timer.elapsed() + timer.reset() // Extract info for nodes at the next lower level. extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) + timer.extractInfoForLowerLevelsTime += timer.elapsed() logDebug("final best split = " + nodeSplitStats._1) } require(math.pow(2, level) == splitsStatsForLevel.length) @@ -129,6 +184,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } + println(timer) + logDebug("#####################################") logDebug("Extracting tree model") logDebug("#####################################") @@ -194,6 +251,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } + object DecisionTree extends Serializable with Logging { /** @@ -325,6 +383,7 @@ object DecisionTree extends Serializable with Logging { filters: Array[List[Filter]], splits: Array[Array[Split]], bins: Array[Array[Bin]], + timer: TimeTracker, maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { @@ -339,13 +398,13 @@ object DecisionTree extends Serializable with Logging { var groupIndex = 0 while (groupIndex < numGroups) { val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, - filters, splits, bins, numGroups, groupIndex) + filters, splits, bins, timer, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins) + findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer) } } @@ -372,6 +431,7 @@ object DecisionTree extends Serializable with Logging { filters: Array[List[Filter]], splits: Array[Array[Split]], bins: Array[Array[Bin]], + timer: TimeTracker, numGroups: Int = 1, groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { @@ -628,9 +688,13 @@ object DecisionTree extends Serializable with Logging { arr } - // Find feature bins for all nodes at a level. + timer.reset() + + // Find feature bins for all nodes at a level. val binMappedRDD = input.map(x => findBinsForLevel(x)) + timer.findBinsForLevelTime += timer.elapsed() + /** * Increment aggregate in location for (node, feature, bin, label). * @@ -873,12 +937,16 @@ object DecisionTree extends Serializable with Logging { combinedAggregate } + timer.reset() + // Calculate bin aggregates. 
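    // The aggregate is a single flat Double array of length binAggregateLength.
    // For ordered classification features this is numClasses * numBins * numFeatures
    // * numNodes, indexed by (node, feature, bin, label) with label least significant:
    // the count for (node n, feature f, bin b, label c) lives at
    //   n * numClasses * numBins * numFeatures + f * numClasses * numBins + b * numClasses + c.
    // When unordered features are present, the array is doubled, with the second half
    // (offset rightChildShift) holding right-child counts.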
val binAggregates = { binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) } logDebug("binAggregates.length = " + binAggregates.length) + timer.binAggregatesTime += timer.elapsed() + /** * Calculates the information gain for all splits based upon left/right split aggregates. * @param leftNodeAgg left node aggregates @@ -1282,6 +1350,8 @@ object DecisionTree extends Serializable with Logging { } } + timer.reset() + // Calculate best splits for all nodes at a given level val bestSplits = new Array[(Split, InformationGainStats)](numNodes) // Iterating over all nodes at this level @@ -1295,6 +1365,8 @@ object DecisionTree extends Serializable with Logging { bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) node += 1 } + timer.chooseSplitsTime += timer.elapsed() + bestSplits } From 3211f027c1a41f8eaa4eea4e90073216a8474c4e Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 8 Aug 2014 09:46:12 -0700 Subject: [PATCH 02/34] Optimizing DecisionTree * Added TreePoint representation to avoid calling findBin multiple times. * (not working yet, but debugging) --- .../spark/mllib/tree/DecisionTree.scala | 159 +++------------- .../spark/mllib/tree/impl/TreePoint.scala | 180 ++++++++++++++++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 72 +++++-- 3 files changed, 263 insertions(+), 148 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index f738418a6c431..0d651d5441782 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impl.TreePoint import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD @@ -92,7 +93,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.reset() // Cache input RDD for speedup during multiple passes. - val retaggedInput = input.retag(classOf[LabeledPoint]).cache() + val retaggedInput = input.retag(classOf[LabeledPoint]) logDebug("algo = " + strategy.algo) timer.initTime += timer.elapsed() @@ -106,6 +107,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.findSplitsBinsTime += timer.elapsed() + timer.reset() + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins) + timer.initTime += timer.elapsed() + // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree @@ -162,8 +167,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. 
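      // TimeTracker pattern used throughout: reset() records a start timestamp,
      // and elapsed() returns the milliseconds since the last reset, which is
      // accumulated into the matching per-phase counter.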
timer.reset() - val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities, - strategy, level, filters, splits, bins, timer, maxLevelForSingleGroup) + val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, + strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer) timer.findBestSplitsTime += timer.elapsed() for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { @@ -463,7 +468,7 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for constructing the DecisionTree @@ -475,22 +480,22 @@ object DecisionTree extends Serializable with Logging { * @return array of splits with best splits for all nodes at a given level. */ protected[tree] def findBestSplits( - input: RDD[LabeledPoint], + input: RDD[TreePoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], bins: Array[Array[Bin]], - timer: TimeTracker, - maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { + maxLevelForSingleGroup: Int, + timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, // the nodes are divided into multiple groups at each level with the number of groups // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt + val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt logDebug("numGroups = " + numGroups) var bestSplits = new Array[(Split, InformationGainStats)](0) // Iterate over each group of nodes at a level. @@ -510,7 +515,7 @@ object DecisionTree extends Serializable with Logging { /** * Returns an array of optimal splits for a group of nodes at a given level * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for constructing the DecisionTree @@ -523,7 +528,7 @@ object DecisionTree extends Serializable with Logging { * @return array of splits with best splits for all nodes at a given level. */ private def findBestSplitsPerGroup( - input: RDD[LabeledPoint], + input: RDD[TreePoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, @@ -601,7 +606,7 @@ object DecisionTree extends Serializable with Logging { * Find whether the sample is valid input for the current node, i.e., whether it passes through * all the filters for the current node. 
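     * A Filter pairs a Split with a comparison code: -1 requires the sample to fall
     * on the split's left side (continuous value <= threshold, or category contained
     * in the split's category set), and 1 requires the right side. For example:
     * {{{
     *   // left child of a root split s:       List(Filter(s, -1))
     *   // right child of that node, via s2:   List(Filter(s2, 1), Filter(s, -1))
     * }}}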
*/ - def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { + def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = { // leaf if ((level > 0) && (parentFilters.length == 0)) { return false @@ -609,20 +614,20 @@ object DecisionTree extends Serializable with Logging { // Apply each filter and check sample validity. Return false when invalid condition found. for (filter <- parentFilters) { - val features = labeledPoint.features val featureIndex = filter.split.feature - val threshold = filter.split.threshold val comparison = filter.comparison - val categories = filter.split.categories val isFeatureContinuous = filter.split.featureType == Continuous - val feature = features(featureIndex) + val binId = treePoint.features(featureIndex) + val bin = bins(featureIndex)(binId) if (isFeatureContinuous) { + val featureValue = bin.highSplit.threshold + val threshold = filter.split.threshold comparison match { - case -1 => if (feature > threshold) return false - case 1 => if (feature <= threshold) return false + case -1 => if (featureValue > threshold) return false + case 1 => if (featureValue <= threshold) return false } } else { - val containsFeature = categories.contains(feature) + val containsFeature = filter.split.categories.contains(bin.category) comparison match { case -1 => if (!containsFeature) return false case 1 => if (containsFeature) return false @@ -635,102 +640,7 @@ object DecisionTree extends Serializable with Logging { true } - /** - * Find bin for one (labeledPoint, feature). - */ - def findBin( - featureIndex: Int, - labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, - isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { - val binForFeatures = bins(featureIndex) - val feature = labeledPoint.features(featureIndex) - - /** - * Binary search helper method for continuous feature. - */ - def binarySearchForBins(): Int = { - var left = 0 - var right = binForFeatures.length - 1 - while (left <= right) { - val mid = left + (right - left) / 2 - val bin = binForFeatures(mid) - val lowThreshold = bin.lowSplit.threshold - val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)) { - return mid - } - else if (lowThreshold >= feature) { - right = mid - 1 - } - else { - left = mid + 1 - } - } - -1 - } - - /** - * Sequential search helper method to find bin for categorical feature in multiclass - * classification. The category is returned since each category can belong to multiple - * splits. The actual left/right child allocation per split is performed in the - * sequential phase of the bin aggregate operation. - */ - def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { - labeledPoint.features(featureIndex).toInt - } - - /** - * Sequential search helper method to find bin for categorical feature - * (for classification and regression). 
- */ - def sequentialBinSearchForOrderedCategoricalFeature(): Int = { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val featureValue = labeledPoint.features(featureIndex) - var binIndex = 0 - while (binIndex < featureCategories) { - val bin = bins(featureIndex)(binIndex) - val categories = bin.highSplit.categories - if (categories.contains(featureValue)) { - return binIndex - } - binIndex += 1 - } - if (featureValue < 0 || featureValue >= featureCategories) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureCategories - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) - } - -1 - } - - if (isFeatureContinuous) { - // Perform binary search for finding bin for continuous features. - val binIndex = binarySearchForBins() - if (binIndex == -1) { - throw new UnknownError("no bin was found for continuous variable.") - } - binIndex - } else { - // Perform sequential search to find bin for categorical features. - val binIndex = { - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - if (isUnorderedFeature) { - sequentialBinSearchForUnorderedCategoricalFeatureInClassification() - } else { - sequentialBinSearchForOrderedCategoricalFeature() - } - } - if (binIndex == -1) { - throw new UnknownError("no bin was found for categorical variable.") - } - binIndex - } - } + // TODO: REMOVED findBin() /** * Finds bins for all nodes (and all features) at a given level. @@ -748,37 +658,26 @@ object DecisionTree extends Serializable with Logging { * bin index for this labeledPoint * (or InvalidBinIndex if labeledPoint is not handled by this node) */ - def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { + def findBinsForLevel(treePoint: TreePoint): Array[Double] = { // Calculate bin index and label per feature per node. val arr = new Array[Double](1 + (numFeatures * numNodes)) // First element of the array is the label of the instance. - arr(0) = labeledPoint.label + arr(0) = treePoint.label // Iterate over nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { val parentFilters = findParentFilters(nodeIndex) // Find out whether the sample qualifies for the particular node. - val sampleValid = isSampleValid(parentFilters, labeledPoint) + val sampleValid = isSampleValid(parentFilters, treePoint) val shift = 1 + numFeatures * nodeIndex if (!sampleValid) { // Mark one bin as -1 is sufficient. 
arr(shift) = InvalidBinIndex } else { var featureIndex = 0 + // TODO: Vectorize this while (featureIndex < numFeatures) { - val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex) - val isFeatureContinuous = featureInfo.isEmpty - if (isFeatureContinuous) { - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) - } else { - val featureCategories = featureInfo.get - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, - isSpaceSufficientForAllCategoricalSplits) - } + arr(shift + featureIndex) = treePoint.features(featureIndex) featureIndex += 1 } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala new file mode 100644 index 0000000000000..f3b5dce041207 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.model.Bin +import org.apache.spark.rdd.RDD + +/** + * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] + * of size (numFeatures, numBins). 
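+ * A TreePoint is a binned representation of a LabeledPoint used during training:
+ * the label plus, for each feature, either the bin index (continuous and ordered
+ * categorical features) or the raw category value (unordered categorical features).
+ * Converting once up front means findBin runs once per (example, feature) rather
+ * than once per tree level.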
+ * TODO: ADD DOC + */ +private[tree] class TreePoint(val label: Double, val features: Array[Int]) { +} + +private[tree] object TreePoint { + + def convertToTreeRDD( + input: RDD[LabeledPoint], + strategy: Strategy, + bins: Array[Array[Bin]]): RDD[TreePoint] = { + input.map { x => + TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins, + strategy.categoricalFeaturesInfo) + } + } + + def labeledPointToTreePoint( + labeledPoint: LabeledPoint, + isMulticlassClassification: Boolean, + bins: Array[Array[Bin]], + categoricalFeaturesInfo: Map[Int, Int]): TreePoint = { + + val numFeatures = labeledPoint.features.size + val numBins = bins(0).size + val arr = new Array[Int](numFeatures) + var featureIndex = 0 // offset by 1 for label + while (featureIndex < numFeatures) { + val featureInfo = categoricalFeaturesInfo.get(featureIndex) + val isFeatureContinuous = featureInfo.isEmpty + if (isFeatureContinuous) { + arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false, + bins, categoricalFeaturesInfo) + } else { + val featureCategories = featureInfo.get + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, + isUnorderedFeature, bins, categoricalFeaturesInfo) + } + featureIndex += 1 + } + + new TreePoint(labeledPoint.label, arr) + } + + + /** + * Find bin for one (labeledPoint, feature). + * + * @param featureIndex + * @param labeledPoint + * @param isFeatureContinuous + * @param isUnorderedFeature (only applies if feature is categorical) + * @param bins Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] + * of size (numFeatures, numBins). + * @param categoricalFeaturesInfo + * @return + */ + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, + isUnorderedFeature: Boolean, + bins: Array[Array[Bin]], + categoricalFeaturesInfo: Map[Int, Int]): Int = { + + /** + * Binary search helper method for continuous feature. + */ + def binarySearchForBins(): Int = { + val binForFeatures = bins(featureIndex) + val feature = labeledPoint.features(featureIndex) + var left = 0 + var right = binForFeatures.length - 1 + while (left <= right) { + val mid = left + (right - left) / 2 + val bin = binForFeatures(mid) + val lowThreshold = bin.lowSplit.threshold + val highThreshold = bin.highSplit.threshold + if ((lowThreshold < feature) && (highThreshold >= feature)) { + return mid + } + else if (lowThreshold >= feature) { + right = mid - 1 + } + else { + left = mid + 1 + } + } + -1 + } + + /** + * Sequential search helper method to find bin for categorical feature in multiclass + * classification. The category is returned since each category can belong to multiple + * splits. The actual left/right child allocation per split is performed in the + * sequential phase of the bin aggregate operation. + */ + def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { + labeledPoint.features(featureIndex).toInt + } + + /** + * Sequential search helper method to find bin for categorical feature + * (for classification and regression). 
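+     * Scans the first featureCategories bins and returns the index of the bin whose
+     * highSplit category set contains the feature value, so each category maps to
+     * exactly one bin.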
+ */ + def sequentialBinSearchForOrderedCategoricalFeature(): Int = { + val featureCategories = categoricalFeaturesInfo(featureIndex) + val featureValue = labeledPoint.features(featureIndex) + var binIndex = 0 + while (binIndex < featureCategories) { + val bin = bins(featureIndex)(binIndex) + val categories = bin.highSplit.categories + if (categories.contains(featureValue)) { + return binIndex + } + binIndex += 1 + } + if (featureValue < 0 || featureValue >= featureCategories) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureCategories - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) + } + -1 + } + + if (isFeatureContinuous) { + // Perform binary search for finding bin for continuous features. + val binIndex = binarySearchForBins() + if (binIndex == -1) { + throw new UnknownError("no bin was found for continuous variable.") + } + binIndex + } else { + // Perform sequential search to find bin for categorical features. + val binIndex = if (isUnorderedFeature) { + sequentialBinSearchForUnorderedCategoricalFeatureInClassification() + } else { + sequentialBinSearchForOrderedCategoricalFeature() + } + if (binIndex == -1) { + throw new UnknownError("no bin was found for categorical variable.") + } + binIndex + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 70ca7c8a266f2..5666064647a10 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree +import org.apache.spark.mllib.tree.impl.TreePoint + import scala.collection.JavaConverters._ import org.scalatest.FunSuite @@ -41,6 +43,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { prediction != expected.label } val accuracy = (input.length - numOffPredictions).toDouble / input.length + if (accuracy < requiredAccuracy) { + println(s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") + } assert(accuracy >= requiredAccuracy) } @@ -427,7 +432,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 @@ -454,7 +460,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 @@ -499,7 +506,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), 
strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -521,7 +529,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -544,7 +553,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -567,7 +577,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -596,7 +607,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. - val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) @@ -604,7 +616,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. 
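    // (numGroups = 2^(level - maxLevelForSingleGroup), so level 1 with
    // maxLevelForSingleGroup = 0 is processed as two groups.)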
- val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, + val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, splits, bins, 0) assert(bestSplitsWithGroups.length === 2) assert(bestSplitsWithGroups(0)._2.gain > 0) @@ -630,7 +642,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -676,6 +689,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } test("stump with categorical variables for multiclass classification, with just enough bins") { + println("START: stump with categorical variables for multiclass classification, with just enough bins") val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val input = sc.parallelize(arr) @@ -683,14 +697,26 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) - val model = DecisionTree.train(input, strategy) - validateClassifier(model, arr, 1.0) - assert(model.numNodes === 3) - assert(model.depth === 1) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) + println(s"splits:") + for (feature <- Range(0,splits.size)) { + for (i <- Range(0,3)) { + println(s" f:$feature [$i]: ${splits(feature)(i)}") + } + } + println(s"bins:") + for (feature <- Range(0,bins.size)) { + for (i <- Range(0,4)) { + println(s" f:$feature [$i]: ${bins(feature)(i)}") + } + } + println(s"bestSplits:") + bestSplits.foreach { x => + println(s"\t $x") + } assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -701,6 +727,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val gain = bestSplits(0)._2 assert(gain.leftImpurity === 0) assert(gain.rightImpurity === 0) + + val model = DecisionTree.train(input, strategy) + println(model) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + println("END: stump with categorical variables for multiclass classification, with just enough bins") } test("stump with continuous variables for multiclass classification") { @@ -714,7 +747,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { validateClassifier(model, arr, 0.9) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -738,7 +772,8 
@@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { validateClassifier(model, arr, 0.9) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) @@ -757,7 +792,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) From 0f676e2e0ae02e54387a255ac9f64d3c7265d152 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 8 Aug 2014 14:12:52 -0700 Subject: [PATCH 03/34] Optimizations + Bug fix for DecisionTree Optimization: Added TreePoint representation so we only call findBin once for each example, feature. Also, calculateGainsForAllNodeSplits now only searches over actual splits, not empty/unused ones. BUG FIX: isSampleValid * isSampleValid used to treat unordered categorical features incorrectly: It treated the bins as if indexed by featured values, rather than by subsets of values/categories. * exhibited for unordered features (multi-class classification with categorical features of low arity) * Fix: Index bins correctly for unordered categorical features. Also: some commented-out debugging println calls in DecisionTree, to be removed later --- .../spark/mllib/tree/DecisionTree.scala | 123 ++++++++++++++---- 1 file changed, 95 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 0d651d5441782..17e7a3e65db60 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -252,6 +252,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // noting the parents filters for the child nodes val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) + //println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}") for (filter <- filters(nodeIndex)) { logDebug("Filter = " + filter) } @@ -477,7 +478,7 @@ object DecisionTree extends Serializable with Logging { * @param splits possible splits for all features * @param bins possible bins for all features * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. - * @return array of splits with best splits for all nodes at a given level. + * @return array (over nodes) of splits with best split for each node at a given level. 
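   *         The i-th element corresponds to the i-th node (left-to-right) at this level.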
*/ protected[tree] def findBestSplits( input: RDD[TreePoint], @@ -490,6 +491,7 @@ object DecisionTree extends Serializable with Logging { maxLevelForSingleGroup: Int, timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation + //println(s"findBestSplits: level = $level") if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, // the nodes are divided into multiple groups at each level with the number of groups @@ -617,9 +619,9 @@ object DecisionTree extends Serializable with Logging { val featureIndex = filter.split.feature val comparison = filter.comparison val isFeatureContinuous = filter.split.featureType == Continuous - val binId = treePoint.features(featureIndex) - val bin = bins(featureIndex)(binId) if (isFeatureContinuous) { + val binId = treePoint.features(featureIndex) + val bin = bins(featureIndex)(binId) val featureValue = bin.highSplit.threshold val threshold = filter.split.threshold comparison match { @@ -627,12 +629,22 @@ object DecisionTree extends Serializable with Logging { case 1 => if (featureValue <= threshold) return false } } else { - val containsFeature = filter.split.categories.contains(bin.category) + val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1 + val isUnorderedFeature = + isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits + val featureValue = if (isUnorderedFeature) { + treePoint.features(featureIndex) + } else { + val binId = treePoint.features(featureIndex) + bins(featureIndex)(binId).category + } + val containsFeature = filter.split.categories.contains(featureValue) comparison match { case -1 => if (!containsFeature) return false case 1 => if (containsFeature) return false } - } } @@ -669,6 +681,7 @@ object DecisionTree extends Serializable with Logging { val parentFilters = findParentFilters(nodeIndex) // Find out whether the sample qualifies for the particular node. val sampleValid = isSampleValid(parentFilters, treePoint) + //println(s"==>findBinsForLevel: node:$nodeIndex, valid=$sampleValid, parentFilters:${parentFilters.mkString(",")}") val shift = 1 + numFeatures * nodeIndex if (!sampleValid) { // Mark one bin as -1 is sufficient. @@ -739,6 +752,7 @@ object DecisionTree extends Serializable with Logging { label: Double, agg: Array[Double], rightChildShift: Int): Unit = { + //println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.") // Find the bin index for this feature. val arrIndex = 1 + numFeatures * nodeIndex + featureIndex val featureValue = arr(arrIndex).toInt @@ -792,6 +806,8 @@ object DecisionTree extends Serializable with Logging { } } + val rightChildShift = numClasses * numBins * numFeatures * numNodes + /** * Helper for binSeqOp. * @@ -814,8 +830,11 @@ object DecisionTree extends Serializable with Logging { // Check whether the instance was valid for this nodeIndex. 
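        // arr(1 + numFeatures * nodeIndex) == InvalidBinIndex marks a sample that was
        // filtered out before reaching this node, so it is skipped here.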
val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (level == 1) { + val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift + //println(s"-multiclassWithCategoricalBinSeqOp: filter: ${filters(nodeFilterIndex)}") + } if (isSampleValidForNode) { - val rightChildShift = numClasses * numBins * numFeatures * numNodes // actual class label val label = arr(0) // Iterate over all features. @@ -874,7 +893,7 @@ object DecisionTree extends Serializable with Logging { val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 agg(aggIndex) = agg(aggIndex) + 1 agg(aggIndex + 1) = agg(aggIndex + 1) + label - agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + agg(aggIndex + 2) = agg(aggIndex + 2) + label * label featureIndex += 1 } } @@ -944,6 +963,29 @@ object DecisionTree extends Serializable with Logging { logDebug("binAggregates.length = " + binAggregates.length) timer.binAggregatesTime += timer.elapsed() + //2 * numClasses * numBins * numFeatures * numNodes for unordered features. + // (left/right, node, feature, bin, label) + /* + println(s"binAggregates:") + for (i <- Range(0,2)) { + for (n <- Range(0,numNodes)) { + for (f <- Range(0,numFeatures)) { + for (b <- Range(0,4)) { + for (c <- Range(0,numClasses)) { + val idx = i * numClasses * numBins * numFeatures * numNodes + + n * numClasses * numBins * numFeatures + + f * numBins * numFeatures + + b * numFeatures + + c + if (binAggregates(idx) != 0) { + println(s"\t ($i, c:$c, b:$b, f:$f, n:$n): ${binAggregates(idx)}") + } + } + } + } + } + } + */ /** * Calculates the information gain for all splits based upon left/right split aggregates. @@ -985,6 +1027,7 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftTotalCount + rightTotalCount if (totalCount == 0) { // Return arbitrary prediction. + //println(s"BLAH: feature $featureIndex, split $splitIndex. 
totalCount == 0") return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) } @@ -997,13 +1040,23 @@ object DecisionTree extends Serializable with Logging { def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { case ((maxIndex, maxValue, currentIndex), currentValue) => - if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1) - else (maxIndex, maxValue, currentIndex + 1) + if (currentValue > maxValue) { + (currentIndex, currentValue, currentIndex + 1) + } else { + (maxIndex, maxValue, currentIndex + 1) + } } - if (result._1 < 0) 0 else result._1 + if (result._1 < 0) { + throw new RuntimeException("DecisionTree internal error:" + + " calculateGainForSplit failed in indexOfLargestArrayElement") + } + result._1 } val predict = indexOfLargestArrayElement(leftRightCounts) + if (predict == 0 && featureIndex == 0 && splitIndex == 0) { + //println(s"AGHGHGHHGHG: leftCounts: ${leftCounts.mkString(",")}, rightCounts: ${rightCounts.mkString(",")}") + } val prob = leftRightCounts(predict) / totalCount val leftImpurity = if (leftTotalCount == 0) { @@ -1023,6 +1076,7 @@ object DecisionTree extends Serializable with Logging { val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + case Regression => val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) @@ -1140,6 +1194,7 @@ object DecisionTree extends Serializable with Logging { val rightChildShift = numClasses * numBins * numFeatures var splitIndex = 0 + var TMPDEBUG = 0.0 while (splitIndex < numBins - 1) { var classIndex = 0 while (classIndex < numClasses) { @@ -1149,10 +1204,12 @@ object DecisionTree extends Serializable with Logging { val rightBinValue = binData(rightChildShift + shift + classIndex) leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue + TMPDEBUG += leftBinValue + rightBinValue classIndex += 1 } splitIndex += 1 } + //println(s"found Agg: $TMPDEBUG") } def findAggForRegression( @@ -1247,7 +1304,8 @@ object DecisionTree extends Serializable with Logging { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) for (featureIndex <- 0 until numFeatures) { - for (splitIndex <- 0 until numBins - 1) { + val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + for (splitIndex <- 0 until numSplitsForFeature) { gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex, rightNodeAgg, nodeImpurity) } @@ -1255,6 +1313,27 @@ object DecisionTree extends Serializable with Logging { gains } + /** + * Get the number of splits for a feature. + */ + def getNumSplitsForFeature(featureIndex: Int): Int = { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + numBins - 1 + } else { + // Categorical feature + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits = + numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + math.pow(2.0, featureCategories - 1).toInt - 1 + } else { + // Ordered features + featureCategories + } + } + } + /** * Find the best split for a node. * @param binData Bin data slice for this node, given by getBinDataForNode. 
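     * @param nodeImpurity impurity of the node being split
     * @return (best split, information gain stats) over all features and splits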
@@ -1273,7 +1352,7 @@ object DecisionTree extends Serializable with Logging { // Calculate gains for all splits. val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - val (bestFeatureIndex,bestSplitIndex, gainStats) = { + val (bestFeatureIndex, bestSplitIndex, gainStats) = { // Initialize with infeasible values. var bestFeatureIndex = Int.MinValue var bestSplitIndex = Int.MinValue @@ -1283,27 +1362,14 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 - val maxSplitIndex: Double = { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - numBins - 1 - } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { - math.pow(2.0, featureCategories - 1).toInt - 1 - } else { // Binary classification - featureCategories - } - } - } - while (splitIndex < maxSplitIndex) { + val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + while (splitIndex < numSplitsForFeature) { val gainStats = gains(featureIndex)(splitIndex) if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex + //println(s" feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats") } splitIndex += 1 } @@ -1361,6 +1427,7 @@ object DecisionTree extends Serializable with Logging { val parentNodeImpurity = parentImpurities(nodeImpurityIndex) logDebug("parent node impurity = " + parentNodeImpurity) bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) + //println(s"bestSplits(node:$node): ${bestSplits(node)}") node += 1 } timer.chooseSplitsTime += timer.elapsed() From b914f3b7ed94e897b55f28c772f48a7d6fba7f06 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sat, 9 Aug 2014 12:01:45 -0700 Subject: [PATCH 04/34] DecisionTree optimization: eliminated filters + small changes DecisionTree.scala * Eliminated filters, replaced by building tree on the fly and filtering top-down. ** Aggregation over examples now skips examples which do not reach the current level. 
* Only calculate unorderedFeatures once (in findSplitsBins) Node: Renamed predictIfLeaf to predict Bin, Split: Updated doc --- .../spark/mllib/tree/DecisionTree.scala | 348 +++++++++++++----- .../apache/spark/mllib/tree/model/Bin.scala | 18 +- .../mllib/tree/model/DecisionTreeModel.scala | 2 +- .../apache/spark/mllib/tree/model/Node.scala | 16 +- .../apache/spark/mllib/tree/model/Split.scala | 5 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 88 ++--- 6 files changed, 326 insertions(+), 151 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 17e7a3e65db60..be57ae7c91832 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,7 +19,10 @@ package org.apache.spark.mllib.tree import java.util.Calendar +import org.apache.spark.mllib.linalg.Vector + import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD @@ -101,7 +104,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(retaggedInput, strategy) val numBins = bins(0).length logDebug("numBins = " + numBins) @@ -116,13 +119,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // the max number of nodes possible given the depth of the tree val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 // Initialize an array to hold filters applied to points for each node. - val filters = new Array[List[Filter]](maxNumNodes) + //val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. - filters(0) = List() + //filters(0) = List() // Initialize an array to hold parent impurity calculations for each node. val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) + val nodesInTree = Array.fill[Boolean](maxNumNodes)(false) // put into nodes array later? + nodesInTree(0) = true // num features val numFeatures = retaggedInput.take(1)(0).features.size @@ -168,23 +173,41 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. timer.reset() val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, - strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer) + strategy, level, nodes, splits, bins, maxLevelForSingleGroup, unorderedFeatures, timer) timer.findBestSplitsTime += timer.elapsed() + val levelNodeIndexOffset = math.pow(2, level).toInt - 1 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + val nodeIndex = levelNodeIndexOffset + index + val isLeftChild = level != 0 && nodeIndex % 2 == 1 + val parentNodeIndex = if (isLeftChild) { // -1 for root node + (nodeIndex - 1) / 2 + } else { + (nodeIndex - 2) / 2 + } + // if (level == 0 || (nodesInTree(parentNodeIndex) && !nodes(parentNodeIndex).isLeaf)) + // TODO: Use above check to skip unused branch of tree + // Extract info for this node (index) at the current level. 
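+        // Nodes are numbered in level order: node n has children 2n + 1 (left) and
+        // 2n + 2 (right); the isLeftChild / parentNodeIndex arithmetic above inverts this.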
timer.reset() - // Extract info for nodes at the current level. extractNodeInfo(nodeSplitStats, level, index, nodes) timer.extractNodeInfoTime += timer.elapsed() - timer.reset() + if (level != 0) { + // Set parent. + if (isLeftChild) { + nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex)) + } else { + nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex)) + } + } // Extract info for nodes at the next lower level. - extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, - filters) + timer.reset() + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities) timer.extractInfoForLowerLevelsTime += timer.elapsed() logDebug("final best split = " + nodeSplitStats._1) } require(math.pow(2, level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. + println(s"LOOP over levels: level=$level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(",")}") val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) if (allLeaf) { @@ -233,30 +256,32 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), - parentImpurities: Array[Double], - filters: Array[List[Filter]]): Unit = { + parentImpurities: Array[Double]): Unit = { + if (level >= maxDepth) + return + //filters: Array[List[Filter]]): Unit = { // 0 corresponds to the left child node and 1 corresponds to the right child node. var i = 0 while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth) { - val impurity = if (i == 0) { - nodeSplitStats._2.leftImpurity - } else { - nodeSplitStats._2.rightImpurity - } - logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) - // noting the parent impurities - parentImpurities(nodeIndex) = impurity - // noting the parents filters for the child nodes - val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) - filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) - //println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}") - for (filter <- filters(nodeIndex)) { - logDebug("Filter = " + filter) - } + val impurity = if (i == 0) { + nodeSplitStats._2.leftImpurity + } else { + nodeSplitStats._2.rightImpurity + } + logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) + // noting the parent impurities + parentImpurities(nodeIndex) = impurity + // noting the parents filters for the child nodes + val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) + /* + filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) + //println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}") + for (filter <- filters(nodeIndex)) { + logDebug("Filter = " + filter) } + */ i += 1 } } @@ -474,21 +499,23 @@ object DecisionTree extends Serializable with Logging { * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for constructing the DecisionTree * @param level Level of the tree - * @param filters Filters for all nodes at a given level * @param splits possible splits for all features * @param bins possible bins for all features * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. 
+ * @param unorderedFeatures Set of unordered (categorical) features. * @return array (over nodes) of splits with best split for each node at a given level. + * TODO: UPDATE DOC */ protected[tree] def findBestSplits( input: RDD[TreePoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, - filters: Array[List[Filter]], + nodes: Array[Node], splits: Array[Array[Split]], bins: Array[Array[Bin]], maxLevelForSingleGroup: Int, + unorderedFeatures: Set[Int], timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation //println(s"findBestSplits: level = $level") @@ -504,13 +531,14 @@ object DecisionTree extends Serializable with Logging { var groupIndex = 0 while (groupIndex < numGroups) { val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, - filters, splits, bins, timer, numGroups, groupIndex) + nodes, splits, bins, unorderedFeatures, timer, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer) + findBestSplitsPerGroup(input, parentImpurities, strategy, level, nodes, splits, bins, + unorderedFeatures, timer) } } @@ -522,21 +550,23 @@ object DecisionTree extends Serializable with Logging { * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for constructing the DecisionTree * @param level Level of the tree - * @param filters Filters for all nodes at a given level * @param splits possible splits for all features - * @param bins possible bins for all features + * @param bins possible bins for all features, indexed as (numFeatures)(numBins) + * @param unorderedFeatures Set of unordered (categorical) features. * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. + * TODO: UPDATE DOC */ private def findBestSplitsPerGroup( input: RDD[TreePoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, - filters: Array[List[Filter]], + nodes: Array[Node], splits: Array[Array[Split]], bins: Array[Array[Bin]], + unorderedFeatures: Set[Int], timer: TimeTracker, numGroups: Int = 1, groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { @@ -595,6 +625,7 @@ object DecisionTree extends Serializable with Logging { val groupShift = numNodes * groupIndex /** Find the filters used before reaching the current code. */ + /* def findParentFilters(nodeIndex: Int): List[Filter] = { if (level == 0) { List[Filter]() @@ -603,11 +634,13 @@ object DecisionTree extends Serializable with Logging { filters(nodeFilterIndex) } } + */ /** * Find whether the sample is valid input for the current node, i.e., whether it passes through * all the filters for the current node. */ + /* def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = { // leaf if ((level > 0) && (parentFilters.length == 0)) { @@ -651,6 +684,79 @@ object DecisionTree extends Serializable with Logging { // Return true when the sample is valid for all filters. true } + */ + + /** + * Get the node index corresponding to this data point. + * This is used during training, mimicking prediction. + * @return Leaf index if the data point reaches a leaf. 
+ * Otherwise, last node reachable in tree matching this example. + */ + def predictNodeIndex(node: Node, features: Array[Int]): Int = { + if (node.isLeaf) { + node.id + } else { + val featureIndex = node.split.get.feature + val splitLeft = node.split.get.featureType match { + case Continuous => { + val binIndex = features(featureIndex) + val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold + // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] + // We do not need to check lowSplit since bins are separated by splits. + featureValueUpperBound <= node.split.get.threshold + } + case Categorical => { + val featureValue = if (unorderedFeatures.contains(featureIndex)) { + features(featureIndex) + } else { + val binIndex = features(featureIndex) + bins(featureIndex)(binIndex).category + } + node.split.get.categories.contains(featureValue) + } + case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") + } + if (node.leftNode.isEmpty || node.rightNode.isEmpty) { + // Return index from next layer of nodes to train + if (splitLeft) { + node.id * 2 + 1 // left + } else { + node.id * 2 + 2 // right + } + } else { + if (splitLeft) { + predictNodeIndex(node.leftNode.get, features) + } else { + predictNodeIndex(node.rightNode.get, features) + } + } + } + } + + def nodeIndexToLevel(idx: Int): Int = { + if (idx == 0) { + 0 + } else { + math.floor(math.log(idx) / math.log(2)).toInt + } + } + + // Used for treePointToNodeIndex + val levelOffset = (math.pow(2, level) - 1).toInt + + /** + * Find the node (indexed from 0 at the start of this level) for the given example. + * If the example does not reach this level, returns a value < 0. + */ + def treePointToNodeIndex(treePoint: TreePoint): Int = { + if (level == 0) { + 0 + } else { + val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.features) + // Get index for this level. + globalNodeIndex - levelOffset + } + } // TODO: REMOVED findBin() @@ -670,6 +776,7 @@ object DecisionTree extends Serializable with Logging { * bin index for this labeledPoint * (or InvalidBinIndex if labeledPoint is not handled by this node) */ + /* def findBinsForLevel(treePoint: TreePoint): Array[Double] = { // Calculate bin index and label per feature per node. val arr = new Array[Double](1 + (numFeatures * numNodes)) @@ -688,7 +795,6 @@ object DecisionTree extends Serializable with Logging { arr(shift) = InvalidBinIndex } else { var featureIndex = 0 - // TODO: Vectorize this while (featureIndex < numFeatures) { arr(shift + featureIndex) = treePoint.features(featureIndex) featureIndex += 1 @@ -698,38 +804,43 @@ object DecisionTree extends Serializable with Logging { } arr } + */ timer.reset() // Find feature bins for all nodes at a level. - val binMappedRDD = input.map(x => findBinsForLevel(x)) + //val binMappedRDD = input.map(x => findBinsForLevel(x)) timer.findBinsForLevelTime += timer.elapsed() /** * Increment aggregate in location for (node, feature, bin, label). * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. * @param agg Array storing aggregate calculation, of size: * numClasses * numBins * numFeatures * numNodes. * Indexed by (node, feature, bin, label) where label is the least significant bit. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). 
+ * TODO: UPDATE DOC */ def updateBinForOrderedFeature( - arr: Array[Double], + treePoint: TreePoint, agg: Array[Double], nodeIndex: Int, - label: Double, featureIndex: Int): Unit = { // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex + //val arrShift = 1 + numFeatures * nodeIndex + //val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. val aggIndex = numClasses * numBins * numFeatures * nodeIndex + numClasses * numBins * featureIndex + - numClasses * arr(arrIndex).toInt + - label.toInt + numClasses * treePoint.features(featureIndex) + //numClasses * arr(arrIndex).toInt + + treePoint.label.toInt + if (aggIndex < 0 || aggIndex >= agg.size) { + val binIndex = treePoint.features(featureIndex) + println(s"aggIndex = $aggIndex, agg.size = ${agg.size}. binIndex = $binIndex, featureIndex = $featureIndex, nodeIndex = $nodeIndex, numBins = $numBins, numFeatures = $numFeatures, level = $level") + } agg(aggIndex) += 1 } @@ -738,29 +849,30 @@ object DecisionTree extends Serializable with Logging { * where [bins] ranges over all bins. * Updates left or right side of aggregate depending on split. * - * @param arr arr(0) = label. - * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category) + * @param treePoint Data point being aggregated. * @param agg Indexed by (left/right, node, feature, bin, label) * where label is the least significant bit. * The left/right specifier is a 0/1 index indicating left/right child info. * @param rightChildShift Offset for right side of agg. + * TODO: UPDATE DOC + * TODO: Make arg order same as for ordered feature. */ def updateBinForUnorderedFeature( nodeIndex: Int, featureIndex: Int, - arr: Array[Double], - label: Double, + treePoint: TreePoint, agg: Array[Double], rightChildShift: Int): Unit = { //println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.") // Find the bin index for this feature. - val arrIndex = 1 + numFeatures * nodeIndex + featureIndex - val featureValue = arr(arrIndex).toInt + //val arrIndex = 1 + numFeatures * nodeIndex + featureIndex + //val featureValue = arr(arrIndex).toInt + val featureValue = treePoint.features(featureIndex) // Update the left or right count for one bin. val aggShift = numClasses * numBins * numFeatures * nodeIndex + numClasses * numBins * featureIndex + - label.toInt + treePoint.label.toInt // Find all matching bins and increment their values val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 @@ -779,13 +891,23 @@ object DecisionTree extends Serializable with Logging { /** * Helper for binSeqOp. * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). * @param agg Array storing aggregate calculation, of size: * numClasses * numBins * numFeatures * numNodes. * Indexed by (node, feature, bin, label) where label is the least significant bit. + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ - def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { + def binaryOrNotCategoricalBinSeqOp( + agg: Array[Double], + treePoint: TreePoint, + nodeIndex: Int): Unit = { + // Iterate over all features. 
+ var featureIndex = 0 + while (featureIndex < numFeatures) { + updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) + featureIndex += 1 + } + /* // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -804,6 +926,7 @@ object DecisionTree extends Serializable with Logging { } nodeIndex += 1 } + */ } val rightChildShift = numClasses * numBins * numFeatures * numNodes @@ -811,35 +934,49 @@ object DecisionTree extends Serializable with Logging { /** * Helper for binSeqOp. * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). - * For ordered features, - * arr(1 + featureIndex + nodeIndex * numFeatures) = bin index. - * For unordered features, - * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category). * @param agg Array storing aggregate calculation. * For ordered features, this is of size: * numClasses * numBins * numFeatures * numNodes. * For unordered features, this is of size: * 2 * numClasses * numBins * numFeatures * numNodes. + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ - def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. + def multiclassWithCategoricalBinSeqOp( + agg: Array[Double], + treePoint: TreePoint, + nodeIndex: Int): Unit = { + val label = treePoint.label + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (unorderedFeatures.contains(featureIndex)) { + updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift) + } else { + updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) + } + featureIndex += 1 + } + /* + // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { // Check whether the instance was valid for this nodeIndex. val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (level == 1) { - val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift - //println(s"-multiclassWithCategoricalBinSeqOp: filter: ${filters(nodeFilterIndex)}") - } if (isSampleValidForNode) { // actual class label val label = arr(0) // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { + if (unorderedFeatures.contains(featureIndex)) { + updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, + rightChildShift) + } else { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) + } + //------ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) @@ -854,11 +991,13 @@ object DecisionTree extends Serializable with Logging { updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) } } + //------ featureIndex += 1 } } nodeIndex += 1 } + */ } /** @@ -868,12 +1007,31 @@ object DecisionTree extends Serializable with Logging { * * @param agg Array storing aggregate calculation, updated by this function. * Size: 3 * numBins * numFeatures * numNodes - * @param arr Bin mapping from findBinsForLevel. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). 
* @return agg */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. + def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = { + // TODO: Move stuff outside loop. + val label = treePoint.label + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Update count, sum, and sum^2 for one bin. + val binIndex = treePoint.features(featureIndex) + val aggIndex = + 3 * numBins * numFeatures * nodeIndex + + 3 * numBins * featureIndex + + 3 * binIndex + if (aggIndex >= agg.size) { + println(s"aggIndex = $aggIndex, agg.size = ${agg.size}. binIndex = $binIndex, featureIndex = $featureIndex, nodeIndex = $nodeIndex, numBins = $numBins, numFeatures = $numFeatures") + } + agg(aggIndex) = agg(aggIndex) + 1 + agg(aggIndex + 1) = agg(aggIndex + 1) + label + agg(aggIndex + 2) = agg(aggIndex + 2) + label * label + featureIndex += 1 + } + /* var nodeIndex = 0 while (nodeIndex < numNodes) { // Check whether the instance was valid for this nodeIndex. @@ -899,6 +1057,7 @@ object DecisionTree extends Serializable with Logging { } nodeIndex += 1 } + */ } /** @@ -916,19 +1075,21 @@ object DecisionTree extends Serializable with Logging { * 2 * numClasses * numBins * numFeatures * numNodes for unordered features. * Size for regression: * 3 * numBins * numFeatures * numNodes. - * @param arr Bin mapping from findBinsForLevel. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. * @return agg */ - def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { - strategy.algo match { - case Classification => - if(isMulticlassClassificationWithCategoricalFeatures) { - multiclassWithCategoricalBinSeqOp(arr, agg) - } else { - binaryOrNotCategoricalBinSeqOp(arr, agg) - } - case Regression => regressionBinSeqOp(arr, agg) + def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = { + val nodeIndex = treePointToNodeIndex(treePoint) + if (nodeIndex >= 0) { // Otherwise, example does not reach this level. + strategy.algo match { + case Classification => + if (isMulticlassClassificationWithCategoricalFeatures) { + multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex) + } else { + binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex) + } + case Regression => regressionBinSeqOp(agg, treePoint, nodeIndex) + } } agg } @@ -958,7 +1119,7 @@ object DecisionTree extends Serializable with Logging { // Calculate bin aggregates. 
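        // Editor's sketch (illustrative, not part of the original patch): the aggregate call
        // below is a standard two-phase fold. Each partition folds its TreePoints into a
        // zeroed copy of the flat bin-count array via binSeqOp, and the per-partition arrays
        // are then merged element-wise via binCombOp. The equivalent computation over an
        // in-memory collection, using the binSeqOp and binAggregateLength defined above:
        def aggregateLocally(points: Seq[TreePoint]): Array[Double] =
          points.foldLeft(Array.fill[Double](binAggregateLength)(0))(binSeqOp)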
val binAggregates = { - binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) } logDebug("binAggregates.length = " + binAggregates.length) @@ -1259,6 +1420,12 @@ object DecisionTree extends Serializable with Logging { val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) var featureIndex = 0 while (featureIndex < numFeatures) { + if (unorderedFeatures.contains(featureIndex)) { + findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } + /* if (isMulticlassClassificationWithCategoricalFeatures) { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { @@ -1276,6 +1443,7 @@ object DecisionTree extends Serializable with Logging { } else { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } + */ featureIndex += 1 } @@ -1323,6 +1491,12 @@ object DecisionTree extends Serializable with Logging { } else { // Categorical feature val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + if (unorderedFeatures.contains(featureIndex)) { + math.pow(2.0, featureCategories - 1).toInt - 1 + } else { + featureCategories + } + /* val isSpaceSufficientForAllCategoricalSplits = numBins > math.pow(2, featureCategories.toInt - 1) - 1 if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { @@ -1331,6 +1505,7 @@ object DecisionTree extends Serializable with Logging { // Ordered features featureCategories } + */ } } @@ -1481,15 +1656,16 @@ object DecisionTree extends Serializable with Logging { * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for construction the DecisionTree - * @return A tuple of (splits,bins). + * @return A tuple of (splits, bins, unorderedFeatures). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] * of size (numFeatures, numBins - 1). * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] * of size (numFeatures, numBins). + * unorderedFeatures: set of indices for unordered features. */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], - strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]], Set[Int]) = { val count = input.count() @@ -1540,6 +1716,9 @@ object DecisionTree extends Serializable with Logging { // Find all splits. + // Record which categorical features will be ordered vs. unordered. + val unorderedFeatures = new mutable.HashSet[Int]() + // Iterate over all features. 
var featureIndex = 0
while (featureIndex < numFeatures) {
@@ -1566,6 +1745,7 @@ object DecisionTree extends Serializable with Logging {
val isUnorderedFeature =
isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
if (isUnorderedFeature) {
+ unorderedFeatures.add(featureIndex)
// 2^(maxFeatureValue - 1) - 1 combinations
var index = 0
while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
@@ -1672,7 +1852,7 @@
}
featureIndex += 1
}
- (splits,bins)
+ (splits, bins, unorderedFeatures.toSet)
case MinMax =>
throw new UnsupportedOperationException("minmax not supported yet.")
case ApproxHist =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index c89c1e371a40e..af35d88f713e5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model
import org.apache.spark.mllib.tree.configuration.FeatureType._
/**
- * Used for "binning" the features bins for faster best split calculation. For a continuous
- * feature, a bin is determined by a low and a high "split". For a categorical feature,
- * the a bin is determined using a single label value (category).
+ * Used for "binning" the feature values for faster best split calculation.
+ *
+ * For a continuous feature, the bin is determined by a low and a high split,
+ * where an example with featureValue falls into the bin such that
+ * lowSplit.threshold < featureValue <= highSplit.threshold.
+ *
+ * For ordered categorical features, there is a 1-1-1 correspondence among
+ * bins, splits, and feature values. The bin is determined by the category/feature value.
+ * However, the bins are not necessarily ordered by feature value;
+ * they are ordered using impurity.
+ * For unordered categorical features, there is a 1-1 correspondence between bins and splits,
+ * where each bin/split corresponds to a subset of feature values (in highSplit.categories).
+ * * @param lowSplit signifying the lower threshold for the continuous feature to be * accepted in the bin * @param highSplit signifying the upper threshold for the continuous feature to be * accepted in the bin * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin for binary classification + * @param category categorical label value accepted in the bin for ordered features */ private[tree] case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 3d3406b5d5f22..0594fd0749d21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @return Double prediction from the trained model */ def predict(features: Vector): Double = { - topNode.predictIfLeaf(features) + topNode.predict(features) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 944f11c2c2e4f..0eee6262781c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -69,24 +69,24 @@ class Node ( /** * predict value if node is not leaf - * @param feature feature value + * @param features feature value * @return predicted value */ - def predictIfLeaf(feature: Vector) : Double = { + def predict(features: Vector) : Double = { if (isLeaf) { predict } else{ if (split.get.featureType == Continuous) { - if (feature(split.get.feature) <= split.get.threshold) { - leftNode.get.predictIfLeaf(feature) + if (features(split.get.feature) <= split.get.threshold) { + leftNode.get.predict(features) } else { - rightNode.get.predictIfLeaf(feature) + rightNode.get.predict(features) } } else { - if (split.get.categories.contains(feature(split.get.feature))) { - leftNode.get.predictIfLeaf(feature) + if (split.get.categories.contains(features(split.get.feature))) { + leftNode.get.predict(features) } else { - rightNode.get.predictIfLeaf(feature) + rightNode.get.predict(features) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index d7ffd386c05ee..50fb48b40de3d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType * :: DeveloperApi :: * Split applied to a feature * @param feature feature index - * @param threshold threshold for continuous feature + * @param threshold Threshold for continuous feature. + * Split left if feature <= threshold, else right. * @param featureType type of feature -- categorical or continuous - * @param categories accepted values for categorical variables + * @param categories Split left if categorical feature value is in this set, else right. 
*/ @DeveloperApi case class Split( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5666064647a10..9e6429f2ff108 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -24,8 +24,8 @@ import scala.collection.JavaConverters._ import org.scalatest.FunSuite import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} -import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vectors @@ -67,7 +67,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -85,7 +85,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -165,7 +165,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) // Check splits. 
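// Editor's note (illustrative, not part of the original patch): the updated tests reflect the
// new findSplitsBins contract, which also returns the set of unordered feature indices. For
// an unordered categorical feature with k categories, the candidate splits are the nonempty
// proper subsets of categories up to left/right symmetry, i.e. 2^(k-1) - 1 of them -- the
// "2^2 - 1 = 3 bins/splits" expected for k = 3 in the multiclass test below:
def numUnorderedSplits(k: Int): Int = math.pow(2.0, k - 1).toInt - 1
assert(numUnorderedSplits(3) == 3)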
@@ -282,7 +282,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -376,7 +376,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) // 2^10 - 1 > 100, so categorical variables will be ordered @@ -431,10 +431,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + new Array[Node](0), splits, bins, 10, unorderedFeatures) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -459,10 +459,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd,strategy) val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + new Array[Node](0), splits, bins, 10, unorderedFeatures) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -498,7 +498,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -508,7 +508,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + new Array[Node](0), splits, bins, 10, unorderedFeatures) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -521,7 +521,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -531,7 +531,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) val bestSplits = 
DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + new Array[Node](0), splits, bins, 10, unorderedFeatures) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -545,7 +545,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -555,7 +555,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + new Array[Node](0), splits, bins, 10, unorderedFeatures) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -569,7 +569,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -579,7 +579,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + new Array[Node](0), splits, bins, 10, unorderedFeatures) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -588,12 +588,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.predict === 1) } + // TODO: Decide about testing 2nd level + /* test("second level node building with/without groups") { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -609,7 +611,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // Single group second level tree construction. 
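    // Editor's note (illustrative, not part of the original patch): this test exercises the
    // grouped code path in findBestSplits, which, for level > maxLevelForSingleGroup, bounds
    // aggregate memory by processing the 2^level nodes of a level in several groups and
    // concatenating the per-group best splits. The group count is presumably derived from
    // the level depth, along the lines of:
    //   val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt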
val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, - splits, bins, 10) + splits, bins, 10, unorderedFeatures) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) assert(bestSplits(1)._2.gain > 0) @@ -632,8 +634,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) } - } + */ test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() @@ -641,10 +643,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy) val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + new Array[Node](0), splits, bins, 10, unorderedFeatures) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -689,34 +691,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } test("stump with categorical variables for multiclass classification, with just enough bins") { - println("START: stump with categorical variables for multiclass classification, with just enough bins") val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val input = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + numClassesForClassification = 3, maxBins = maxBins, + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy) val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) - println(s"splits:") - for (feature <- Range(0,splits.size)) { - for (i <- Range(0,3)) { - println(s" f:$feature [$i]: ${splits(feature)(i)}") - } - } - println(s"bins:") - for (feature <- Range(0,bins.size)) { - for (i <- Range(0,4)) { - println(s" f:$feature [$i]: ${bins(feature)(i)}") - } - } - println(s"bestSplits:") - bestSplits.foreach { x => - println(s"\t $x") - } + new Array[Node](0), splits, bins, 10, unorderedFeatures) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -729,11 +715,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(gain.rightImpurity === 0) val model = DecisionTree.train(input, strategy) - println(model) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) - println("END: stump with categorical variables for multiclass classification, with just enough bins") } test("stump with continuous variables for 
multiclass classification") { @@ -746,10 +730,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(input, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy) val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + new Array[Node](0), splits, bins, 10, unorderedFeatures) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -771,10 +755,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(input, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy) val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + new Array[Node](0), splits, bins, 10, unorderedFeatures) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -791,10 +775,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy) val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + new Array[Node](0), splits, bins, 10, unorderedFeatures) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 From c1565a5248e5d0ccc2293315799281030a74c217 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 11 Aug 2014 11:09:32 -0700 Subject: [PATCH 05/34] Small DecisionTree updates: * Simplification: Updated calculateGainForSplit to take aggregates for a single (feature, split) pair. 
* Internal doc: findAggForOrderedFeatureClassification --- .../spark/mllib/tree/DecisionTree.scala | 309 +++--------------- 1 file changed, 38 insertions(+), 271 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index be57ae7c91832..4ac9ce67c5c47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -55,7 +55,6 @@ class TimeTracker { var extractNodeInfoTime: Long = 0 var extractInfoForLowerLevelsTime: Long = 0 var findBestSplitsTime: Long = 0 - var findBinsForLevelTime: Long = 0 var binAggregatesTime: Long = 0 var chooseSplitsTime: Long = 0 @@ -66,7 +65,6 @@ class TimeTracker { s"extractNodeInfoTime: $extractNodeInfoTime\n" + s"extractInfoForLowerLevelsTime: $extractInfoForLowerLevelsTime\n" + s"findBestSplitsTime: $findBestSplitsTime\n" + - s"findBinsForLevelTime: $findBinsForLevelTime\n" + s"binAggregatesTime: $binAggregatesTime\n" + s"chooseSplitsTime: $chooseSplitsTime\n" } @@ -624,68 +622,6 @@ object DecisionTree extends Serializable with Logging { // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex - /** Find the filters used before reaching the current code. */ - /* - def findParentFilters(nodeIndex: Int): List[Filter] = { - if (level == 0) { - List[Filter]() - } else { - val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift - filters(nodeFilterIndex) - } - } - */ - - /** - * Find whether the sample is valid input for the current node, i.e., whether it passes through - * all the filters for the current node. - */ - /* - def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = { - // leaf - if ((level > 0) && (parentFilters.length == 0)) { - return false - } - - // Apply each filter and check sample validity. Return false when invalid condition found. - for (filter <- parentFilters) { - val featureIndex = filter.split.feature - val comparison = filter.comparison - val isFeatureContinuous = filter.split.featureType == Continuous - if (isFeatureContinuous) { - val binId = treePoint.features(featureIndex) - val bin = bins(featureIndex)(binId) - val featureValue = bin.highSplit.threshold - val threshold = filter.split.threshold - comparison match { - case -1 => if (featureValue > threshold) return false - case 1 => if (featureValue <= threshold) return false - } - } else { - val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits = - numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1 - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - val featureValue = if (isUnorderedFeature) { - treePoint.features(featureIndex) - } else { - val binId = treePoint.features(featureIndex) - bins(featureIndex)(binId).category - } - val containsFeature = filter.split.categories.contains(featureValue) - comparison match { - case -1 => if (!containsFeature) return false - case 1 => if (containsFeature) return false - } - } - } - - // Return true when the sample is valid for all filters. - true - } - */ - /** * Get the node index corresponding to this data point. * This is used during training, mimicking prediction. 
@@ -758,61 +694,6 @@ object DecisionTree extends Serializable with Logging { } } - // TODO: REMOVED findBin() - - /** - * Finds bins for all nodes (and all features) at a given level. - * For l nodes, k features the storage is as follows: - * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, - * where b_ij is an integer between 0 and numBins - 1 for regressions and binary - * classification and the categorical feature value in multiclass classification. - * Invalid sample is denoted by noting bin for feature 1 as -1. - * - * For unordered features, the "bin index" returned is actually the feature value (category). - * - * @return Array of size 1 + numFeatures * numNodes, where - * arr(0) = label for labeledPoint, and - * arr(1 + numFeatures * nodeIndex + featureIndex) = - * bin index for this labeledPoint - * (or InvalidBinIndex if labeledPoint is not handled by this node) - */ - /* - def findBinsForLevel(treePoint: TreePoint): Array[Double] = { - // Calculate bin index and label per feature per node. - val arr = new Array[Double](1 + (numFeatures * numNodes)) - // First element of the array is the label of the instance. - arr(0) = treePoint.label - // Iterate over nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - val parentFilters = findParentFilters(nodeIndex) - // Find out whether the sample qualifies for the particular node. - val sampleValid = isSampleValid(parentFilters, treePoint) - //println(s"==>findBinsForLevel: node:$nodeIndex, valid=$sampleValid, parentFilters:${parentFilters.mkString(",")}") - val shift = 1 + numFeatures * nodeIndex - if (!sampleValid) { - // Mark one bin as -1 is sufficient. - arr(shift) = InvalidBinIndex - } else { - var featureIndex = 0 - while (featureIndex < numFeatures) { - arr(shift + featureIndex) = treePoint.features(featureIndex) - featureIndex += 1 - } - } - nodeIndex += 1 - } - arr - } - */ - - timer.reset() - - // Find feature bins for all nodes at a level. - //val binMappedRDD = input.map(x => findBinsForLevel(x)) - - timer.findBinsForLevelTime += timer.elapsed() - /** * Increment aggregate in location for (node, feature, bin, label). * @@ -828,19 +709,12 @@ object DecisionTree extends Serializable with Logging { agg: Array[Double], nodeIndex: Int, featureIndex: Int): Unit = { - // Find the bin index for this feature. - //val arrShift = 1 + numFeatures * nodeIndex - //val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. val aggIndex = numClasses * numBins * numFeatures * nodeIndex + numClasses * numBins * featureIndex + - numClasses * treePoint.features(featureIndex) + //numClasses * arr(arrIndex).toInt + + numClasses * treePoint.features(featureIndex) + treePoint.label.toInt - if (aggIndex < 0 || aggIndex >= agg.size) { - val binIndex = treePoint.features(featureIndex) - println(s"aggIndex = $aggIndex, agg.size = ${agg.size}. binIndex = $binIndex, featureIndex = $featureIndex, nodeIndex = $nodeIndex, numBins = $numBins, numFeatures = $numFeatures, level = $level") - } agg(aggIndex) += 1 } @@ -864,9 +738,6 @@ object DecisionTree extends Serializable with Logging { agg: Array[Double], rightChildShift: Int): Unit = { //println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.") - // Find the bin index for this feature. - //val arrIndex = 1 + numFeatures * nodeIndex + featureIndex - //val featureValue = arr(arrIndex).toInt val featureValue = treePoint.features(featureIndex) // Update the left or right count for one bin. 
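        // Editor's sketch (illustrative, not part of the original patch): for unordered
        // features the aggregate holds two mirrored (node, feature, bin, label) blocks --
        // left-child counts first, right-child counts offset by rightChildShift. Modulo the
        // matching-bin check below, the per-bin offset arithmetic is along the lines of:
        def unorderedAggIndex(isLeft: Boolean, aggShift: Int, binIndex: Int): Int = {
          val idx = aggShift + binIndex * numClasses
          if (isLeft) idx else rightChildShift + idx
        }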
val aggShift = @@ -907,26 +778,6 @@ object DecisionTree extends Serializable with Logging { updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) featureIndex += 1 } - /* - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - featureIndex += 1 - } - } - nodeIndex += 1 - } - */ } val rightChildShift = numClasses * numBins * numFeatures * numNodes @@ -957,47 +808,6 @@ object DecisionTree extends Serializable with Logging { } featureIndex += 1 } - /* - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (unorderedFeatures.contains(featureIndex)) { - updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, - rightChildShift) - } else { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } - //------ - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } else { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isSpaceSufficientForAllCategoricalSplits) { - updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, - rightChildShift) - } else { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } - } - //------ - featureIndex += 1 - } - } - nodeIndex += 1 - } - */ } /** @@ -1031,33 +841,6 @@ object DecisionTree extends Serializable with Logging { agg(aggIndex + 2) = agg(aggIndex + 2) + label * label featureIndex += 1 } - /* - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update count, sum, and sum^2 for one bin. - val aggShift = 3 * numBins * numFeatures * nodeIndex - val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 - agg(aggIndex) = agg(aggIndex) + 1 - agg(aggIndex + 1) = agg(aggIndex + 1) + label - agg(aggIndex + 2) = agg(aggIndex + 2) + label * label - featureIndex += 1 - } - } - nodeIndex += 1 - } - */ } /** @@ -1149,26 +932,20 @@ object DecisionTree extends Serializable with Logging { */ /** - * Calculates the information gain for all splits based upon left/right split aggregates. 
- * @param leftNodeAgg left node aggregates - * @param featureIndex feature index - * @param splitIndex split index - * @param rightNodeAgg right node aggregate + * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * @param leftNodeAgg left node aggregates for this (feature, split) + * @param rightNodeAgg right node aggregate for this (feature, split) * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ def calculateGainForSplit( - leftNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int, - splitIndex: Int, - rightNodeAgg: Array[Array[Array[Double]]], + leftNodeAgg: Array[Double], + rightNodeAgg: Array[Double], topImpurity: Double): InformationGainStats = { strategy.algo match { case Classification => - val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) - val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) - val leftTotalCount = leftCounts.sum - val rightTotalCount = rightCounts.sum + val leftTotalCount = leftNodeAgg.sum + val rightTotalCount = rightNodeAgg.sum val impurity = { if (level > 0) { @@ -1178,7 +955,7 @@ object DecisionTree extends Serializable with Logging { val rootNodeCounts = new Array[Double](numClasses) var classIndex = 0 while (classIndex < numClasses) { - rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex) + rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex) classIndex += 1 } strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) @@ -1193,8 +970,8 @@ object DecisionTree extends Serializable with Logging { } // Sum of count for each label - val leftRightCounts: Array[Double] = - leftCounts.zip(rightCounts).map { case (leftCount, rightCount) => + val leftrightNodeAgg: Array[Double] = + leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) => leftCount + rightCount } @@ -1214,21 +991,18 @@ object DecisionTree extends Serializable with Logging { result._1 } - val predict = indexOfLargestArrayElement(leftRightCounts) - if (predict == 0 && featureIndex == 0 && splitIndex == 0) { - //println(s"AGHGHGHHGHG: leftCounts: ${leftCounts.mkString(",")}, rightCounts: ${rightCounts.mkString(",")}") - } - val prob = leftRightCounts(predict) / totalCount + val predict = indexOfLargestArrayElement(leftrightNodeAgg) + val prob = leftrightNodeAgg(predict) / totalCount val leftImpurity = if (leftTotalCount == 0) { topImpurity } else { - strategy.impurity.calculate(leftCounts, leftTotalCount) + strategy.impurity.calculate(leftNodeAgg, leftTotalCount) } val rightImpurity = if (rightTotalCount == 0) { topImpurity } else { - strategy.impurity.calculate(rightCounts, rightTotalCount) + strategy.impurity.calculate(rightNodeAgg, rightTotalCount) } val leftWeight = leftTotalCount / totalCount @@ -1239,13 +1013,13 @@ object DecisionTree extends Serializable with Logging { new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) case Regression => - val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) - val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) - val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2) + val leftCount = leftNodeAgg(0) + val leftSum = leftNodeAgg(1) + val leftSumSquares = leftNodeAgg(2) - val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0) - val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1) - val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2) + val rightCount = rightNodeAgg(0) + 
val rightSum = rightNodeAgg(1) + val rightSumSquares = rightNodeAgg(2) val impurity = { if (level > 0) { @@ -1306,6 +1080,20 @@ object DecisionTree extends Serializable with Logging { binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + /** + * The input binData is indexed as (feature, bin, class). + * This computes cumulative sums over splits. + * Each (feature, class) pair is handled separately. + * Note: numSplits = numBins - 1. + * @param leftNodeAgg Each (feature, class) slice is an array over splits. + * Element i (i = 0, ..., numSplits - 2) is set to be + * the cumulative sum (from left) over binData for bins 0, ..., i. + * @param rightNodeAgg Each (feature, class) slice is an array over splits. + * Element i (i = 1, ..., numSplits - 1) is set to be + * the cumulative sum (from right) over binData for bins + * numBins - 1, ..., numBins - 1 - i. + * TODO: We could avoid doing one of these cumulative sums. + */ def findAggForOrderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], @@ -1355,7 +1143,6 @@ object DecisionTree extends Serializable with Logging { val rightChildShift = numClasses * numBins * numFeatures var splitIndex = 0 - var TMPDEBUG = 0.0 while (splitIndex < numBins - 1) { var classIndex = 0 while (classIndex < numClasses) { @@ -1365,12 +1152,10 @@ object DecisionTree extends Serializable with Logging { val rightBinValue = binData(rightChildShift + shift + classIndex) leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue - TMPDEBUG += leftBinValue + rightBinValue classIndex += 1 } splitIndex += 1 } - //println(s"found Agg: $TMPDEBUG") } def findAggForRegression( @@ -1425,25 +1210,6 @@ object DecisionTree extends Serializable with Logging { } else { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } - /* - if (isMulticlassClassificationWithCategoricalFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isSpaceSufficientForAllCategoricalSplits) { - findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - } - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - */ featureIndex += 1 } @@ -1474,8 +1240,9 @@ object DecisionTree extends Serializable with Logging { for (featureIndex <- 0 until numFeatures) { val numSplitsForFeature = getNumSplitsForFeature(featureIndex) for (splitIndex <- 0 until numSplitsForFeature) { - gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, - splitIndex, rightNodeAgg, nodeImpurity) + gains(featureIndex)(splitIndex) = + calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex), + rightNodeAgg(featureIndex)(splitIndex), nodeImpurity) } } gains From fd653725dff2ad1de2aaf7eac0b06bbeee8d1129 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 12 Aug 2014 21:13:09 -0700 Subject: [PATCH 06/34] Major changes: * Created ImpurityAggregator classes, rather than old aggregates. 
* Feature split/bin semantics are based on ordered vs. unordered ** E.g.: numSplits = numBins for all unordered features, and numSplits = numBins - 1 for all ordered features. * numBins can differ for each feature DecisionTree * Major changes based on new aggregator setup ** For ordered features, aggregate is indexed by: (nodeIndex)(featureIndex)(binIndex). ** For unordered features, aggregate is indexed by: (nodeIndex)(featureIndex)(2 * binIndex), * Added LearningMetadata class * Eliminated now-unused functions: ** extractNodeInfo ** getNumSplitsForFeature ** getBinDataForNode (Eliminated since it merely slices/reshapes data.) ImpurityAggregator classes * Changed main aggregate operation to create binAggregates (binSeqOp, binCompOp) to use the aggregator. * Before, for unordered features, the left/right bins were treated as a separate dimension for aggregates. They are now part of the bins: binAggregates is of size: (numNodes, numBins_f, numFeatures) where numBins_f is: ** 2 * [pow(2, maxFeatureValue - 1) - 1] for unordered categorical features ** maxFeatureValue for ordered categorical features ** maxBins for continuous features DecisionTreeSuite * For tests using unordered (low-arity) features, removed checks of Bin.category, which only has meaning for ordered features. --- .../spark/mllib/tree/DecisionTree.scala | 1304 ++++++++--------- .../spark/mllib/tree/impl/TreePoint.scala | 31 +- .../spark/mllib/tree/impurity/Entropy.scala | 51 + .../spark/mllib/tree/impurity/Gini.scala | 54 + .../spark/mllib/tree/impurity/Impurity.scala | 49 + .../spark/mllib/tree/impurity/Variance.scala | 36 + .../spark/mllib/tree/DecisionTreeSuite.scala | 338 +++-- 7 files changed, 975 insertions(+), 888 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 4ac9ce67c5c47..8e4bb40a4b3bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,8 +19,6 @@ package org.apache.spark.mllib.tree import java.util.Calendar -import org.apache.spark.mllib.linalg.Vector - import scala.collection.JavaConverters._ import scala.collection.mutable @@ -28,17 +26,17 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.impl.TreePoint -import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} +import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom -class TimeTracker { +private[tree] class TimeTracker { var tmpTime: Long = Calendar.getInstance().getTimeInMillis @@ -70,6 +68,101 @@ class TimeTracker { } } + +/** + * Categorical feature metadata. + * + * TODO: Add doc about ordered vs. unordered features. + * Ensure numBins is always greater than the categories. For multiclass classification, + * numBins should be greater than math.pow(2, maxCategories - 1) - 1. 
+ * It's a limitation of the current implementation but a reasonable trade-off since features
+ * with a large number of categories get favored over continuous features.
+ *
+ * This needs to be checked here instead of in Strategy since numBins can be determined
+ * by the number of training examples.
+ * TODO: Allow this case, where we simply will know nothing about some categories.
+ *
+ * @param featureArity Map: categorical feature index --> arity.
+ * I.e., the feature takes values in {0, ..., arity - 1}.
+ */
+private[tree] class LearningMetadata(
+ val numFeatures: Int,
+ val numExamples: Long,
+ val numClasses: Int,
+ val maxBins: Int,
+ val featureArity: Map[Int, Int],
+ val unorderedFeatures: Set[Int],
+ val numBins: Array[Int],
+ val impurity: Impurity,
+ val quantileStrategy: QuantileStrategy) extends Serializable {
+
+ def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
+
+ def isClassification: Boolean = numClasses >= 2
+
+ def isMulticlass: Boolean = numClasses > 2
+
+ def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0)
+
+ def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex)
+
+ def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+
+ def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
+ numBins(featureIndex)
+ } else {
+ numBins(featureIndex) - 1
+ }
+
+}
+
+private[tree] object LearningMetadata {
+
+ def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): LearningMetadata = {
+
+ val numFeatures = input.take(1)(0).features.size
+ val numExamples = input.count()
+ val numClasses = strategy.algo match {
+ case Classification => strategy.numClassesForClassification
+ case Regression => 0
+ }
+
+ val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+
+ val unorderedFeatures = new mutable.HashSet[Int]()
+ // numBins[featureIndex] = number of bins for feature
+ val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
+ if (numClasses > 2) {
+ strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
+ val numUnorderedBins = DecisionTree.numUnorderedBins(k)
+ if (numUnorderedBins < maxPossibleBins) {
+ numBins(f) = numUnorderedBins
+ unorderedFeatures.add(f)
+ } else {
+ // TODO: Check the below k <= maxBins.
+ // This used to be k < maxPossibleBins, but <= should work.
+ // However, there may have been an off-by-one error later on allocating 1 extra
+ // (unused) bin.
+ require(k <= maxPossibleBins, "maxBins must be at least the max number of " +
+ "categories in categorical features")
+ numBins(f) = k
+ }
+ }
+ } else {
+ strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
+ require(k <= maxPossibleBins, "maxBins must be at least the max number of " +
+ "categories in categorical features")
+ numBins(f) = k
+ }
+ }
+
+ new LearningMetadata(numFeatures, numExamples, numClasses, numBins.max,
+ strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
+ strategy.impurity, strategy.quantileCalculationStrategy)
+ }
+}
+
+
/**
* :: Experimental ::
* A class which implements a decision tree learning algorithm for classification and regression.
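// Editor's sketch (illustrative, not part of the original patch): the split/bin relationship
// encoded by LearningMetadata.numSplits above. Unordered categorical features pair each bin
// with one subset-valued split, while ordered and continuous features have one fewer split
// than bins (splits are the boundaries between consecutive bins):
def numSplitsFor(numBinsForFeature: Int, isUnordered: Boolean): Int =
  if (isUnordered) numBinsForFeature else numBinsForFeature - 1
// e.g. a continuous feature with 32 bins yields 31 candidate thresholds.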
@@ -97,25 +190,43 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val retaggedInput = input.retag(classOf[LabeledPoint]) logDebug("algo = " + strategy.algo) + val metadata = LearningMetadata.buildMetadata(retaggedInput, strategy) + logDebug("maxBins = " + metadata.maxBins) + timer.initTime += timer.elapsed() timer.reset() // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(retaggedInput, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) val numBins = bins(0).length logDebug("numBins = " + numBins) timer.findSplitsBinsTime += timer.elapsed() + /* + println(s"splits:") + for (f <- Range(0, splits.size)) { + for (s <- Range(0, splits(f).size)) { + println(s" splits($f)($s): ${splits(f)(s)}") + } + } + println(s"bins:") + for (f <- Range(0, bins.size)) { + for (s <- Range(0, bins(f).size)) { + println(s" bins($f)($s): ${bins(f)(s)}") + } + } + */ + timer.reset() - val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins) + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) timer.initTime += timer.elapsed() // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 + val maxNumNodes = DecisionTree.maxNodesInLevel(maxDepth + 1) - 1 // Initialize an array to hold filters applied to points for each node. //val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -127,13 +238,14 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val nodesInTree = Array.fill[Boolean](maxNumNodes)(false) // put into nodes array later? nodesInTree(0) = true // num features - val numFeatures = retaggedInput.take(1)(0).features.size + val numFeatures = metadata.numFeatures // Calculate level for single group construction // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") + // TODO: Calculate numElementsPerNode in metadata (more precisely) val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins, strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures, strategy.algo) @@ -155,14 +267,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * still survived the filters of the parent nodes. */ - var findBestSplitsTime: Long = 0 - var extractNodeInfoTime: Long = 0 - var extractInfoForLowerLevelsTime: Long = 0 - var level = 0 var break = false while (level <= maxDepth && !break) { + //println(s"LEVEL $level") logDebug("#####################################") logDebug("level = " + level) logDebug("#####################################") @@ -170,12 +279,18 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. 
       timer.reset()
-      val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
-        strategy, level, nodes, splits, bins, maxLevelForSingleGroup, unorderedFeatures, timer)
+      val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
+        DecisionTree.findBestSplits(treeInput, parentImpurities,
+          metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
       timer.findBestSplitsTime += timer.elapsed()
 
-      val levelNodeIndexOffset = math.pow(2, level).toInt - 1
+      val levelNodeIndexOffset = DecisionTree.maxNodesInLevel(level) - 1
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
+        /*
+        println(s"splitsStatsForLevel: index=$index")
+        println(s"\t split: ${nodeSplitStats._1}")
+        println(s"\t gain stats: ${nodeSplitStats._2}")
+        */
         val nodeIndex = levelNodeIndexOffset + index
         val isLeftChild = level != 0 && nodeIndex % 2 == 1
         val parentNodeIndex = if (isLeftChild) { // -1 for root node
@@ -187,7 +302,13 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
         // TODO: Use above check to skip unused branch of tree
         // Extract info for this node (index) at the current level.
         timer.reset()
-        extractNodeInfo(nodeSplitStats, level, index, nodes)
+        val split = nodeSplitStats._1
+        val stats = nodeSplitStats._2
+        val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
+        val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+        logDebug("Node = " + node)
+        nodes(nodeIndex) = node
+
         timer.extractNodeInfoTime += timer.elapsed()
         if (level != 0) {
           // Set parent.
@@ -203,9 +324,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
         timer.extractInfoForLowerLevelsTime += timer.elapsed()
         logDebug("final best split = " + nodeSplitStats._1)
       }
-      require(math.pow(2, level) == splitsStatsForLevel.length)
+      require(DecisionTree.maxNodesInLevel(level) == splitsStatsForLevel.length)
       // Check whether all the nodes at the current level are leaves.
-      println(s"LOOP over levels: level=$level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(",")}")
+      //println(s"LOOP over levels: level=$level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(",")}")
       val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
       logDebug("all leaf = " + allLeaf)
       if (allLeaf) {
@@ -229,23 +350,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     new DecisionTreeModel(topNode, strategy.algo)
   }
 
-  /**
-   * Extract the decision tree node information for the given tree level and node index
-   */
-  private def extractNodeInfo(
-      nodeSplitStats: (Split, InformationGainStats),
-      level: Int,
-      index: Int,
-      nodes: Array[Node]): Unit = {
-    val split = nodeSplitStats._1
-    val stats = nodeSplitStats._2
-    val nodeIndex = math.pow(2, level).toInt - 1 + index
-    val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
-    val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
-    logDebug("Node = " + node)
-    nodes(nodeIndex) = node
-  }
-
   /**
   * Extract the decision tree node information for the children of the node
   */
@@ -257,12 +361,13 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       parentImpurities: Array[Double]): Unit = {
 
     if (level >= maxDepth) return
-    //filters: Array[List[Filter]]): Unit = {
+    // TODO: Move nodeIndexOffset calc out of function?
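+    // Editor's illustration: nodes are stored heap-style, so level L occupies indices
+    // 2^L - 1 through 2^(L+1) - 2, and the children of node i sit at 2i + 1 and 2i + 2
+    // (e.g. level 2 holds indices 3..6, and node 3's children are 7 and 8).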
+ val nodeIndexOffset = DecisionTree.maxNodesInLevel(level + 1) - 1 // 0 corresponds to the left child node and 1 corresponds to the right child node. var i = 0 while (i <= 1) { - // Calculate the index of the node from the node level and the index at the current level. - val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i + // Calculate the index of the node from the node level and the index at the current level. + val nodeIndex = nodeIndexOffset + 2 * index + i val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity } else { @@ -271,15 +376,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) // noting the parent impurities parentImpurities(nodeIndex) = impurity - // noting the parents filters for the child nodes - val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) - /* - filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) - //println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}") - for (filter <- filters(nodeIndex)) { - logDebug("Filter = " + filter) - } - */ i += 1 } } @@ -497,8 +593,8 @@ object DecisionTree extends Serializable with Logging { * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for constructing the DecisionTree * @param level Level of the tree - * @param splits possible splits for all features - * @param bins possible bins for all features + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. * @param unorderedFeatures Set of unordered (categorical) features. * @return array (over nodes) of splits with best split for each node at a given level. @@ -507,13 +603,12 @@ object DecisionTree extends Serializable with Logging { protected[tree] def findBestSplits( input: RDD[TreePoint], parentImpurities: Array[Double], - strategy: Strategy, + metadata: LearningMetadata, level: Int, nodes: Array[Node], splits: Array[Array[Split]], bins: Array[Array[Bin]], maxLevelForSingleGroup: Int, - unorderedFeatures: Set[Int], timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation //println(s"findBestSplits: level = $level") @@ -528,19 +623,18 @@ object DecisionTree extends Serializable with Logging { // Iterate over each group of nodes at a level. 
var groupIndex = 0 while (groupIndex < numGroups) { - val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, - nodes, splits, bins, unorderedFeatures, timer, numGroups, groupIndex) + val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level, + nodes, splits, bins, timer, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, parentImpurities, strategy, level, nodes, splits, bins, - unorderedFeatures, timer) + findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer) } } - /** + /** * Returns an array of optimal splits for a group of nodes at a given level * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] @@ -548,8 +642,8 @@ object DecisionTree extends Serializable with Logging { * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for constructing the DecisionTree * @param level Level of the tree - * @param splits possible splits for all features - * @param bins possible bins for all features, indexed as (numFeatures)(numBins) + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param unorderedFeatures Set of unordered (categorical) features. * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. @@ -559,12 +653,11 @@ object DecisionTree extends Serializable with Logging { private def findBestSplitsPerGroup( input: RDD[TreePoint], parentImpurities: Array[Double], - strategy: Strategy, + metadata: LearningMetadata, level: Int, nodes: Array[Node], splits: Array[Array[Split]], bins: Array[Array[Bin]], - unorderedFeatures: Set[Int], timer: TimeTracker, numGroups: Int = 1, groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { @@ -597,7 +690,7 @@ object DecisionTree extends Serializable with Logging { // numNodes: Number of nodes in this (level of tree, group), // where nodes at deeper (larger) levels may be divided into groups. - val numNodes = math.pow(2, level).toInt / numGroups + val numNodes = DecisionTree.maxNodesInLevel(level) / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. 
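[Editor's illustration] The node-group arithmetic above, as a standalone sketch; it assumes numGroups = 2^(level - maxLevelForSingleGroup), which matches the exponential grouping described in findBestSplits but is not shown in this excerpt:

object NodeGroupSketch {
  def main(args: Array[String]): Unit = {
    val maxLevelForSingleGroup = 10
    val level = 12
    // Levels deeper than maxLevelForSingleGroup are processed in exponentially more groups...
    val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt // 4
    // ...and each call to findBestSplitsPerGroup handles an equal share of the level's nodes.
    val numNodes = math.pow(2, level).toInt / numGroups // 4096 / 4 = 1024
    println(s"numGroups = $numGroups, nodes per group = $numNodes")
  }
}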
@@ -608,16 +701,14 @@ object DecisionTree extends Serializable with Logging { val numBins = bins(0).length logDebug("numBins = " + numBins) - val numClasses = strategy.numClassesForClassification + val numClasses = metadata.numClasses logDebug("numClasses = " + numClasses) - val isMulticlassClassification = strategy.isMulticlassClassification - logDebug("isMulticlassClassification = " + isMulticlassClassification) + val isMulticlass = metadata.isMulticlass + logDebug("isMulticlass = " + isMulticlass) - val isMulticlassClassificationWithCategoricalFeatures - = strategy.isMulticlassWithCategoricalFeatures - logDebug("isMultiClassWithCategoricalFeatures = " + - isMulticlassClassificationWithCategoricalFeatures) + val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures + logDebug("isMulticlassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex @@ -642,7 +733,7 @@ object DecisionTree extends Serializable with Logging { featureValueUpperBound <= node.split.get.threshold } case Categorical => { - val featureValue = if (unorderedFeatures.contains(featureIndex)) { + val featureValue = if (metadata.isUnordered(featureIndex)) { features(featureIndex) } else { val binIndex = features(featureIndex) @@ -678,7 +769,7 @@ object DecisionTree extends Serializable with Logging { } // Used for treePointToNodeIndex - val levelOffset = (math.pow(2, level) - 1).toInt + val levelOffset = DecisionTree.maxNodesInLevel(level) - 1 /** * Find the node (indexed from 0 at the start of this level) for the given example. @@ -694,92 +785,6 @@ object DecisionTree extends Serializable with Logging { } } - /** - * Increment aggregate in location for (node, feature, bin, label). - * - * @param treePoint Data point being aggregated. - * @param agg Array storing aggregate calculation, of size: - * numClasses * numBins * numFeatures * numNodes. - * Indexed by (node, feature, bin, label) where label is the least significant bit. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - * TODO: UPDATE DOC - */ - def updateBinForOrderedFeature( - treePoint: TreePoint, - agg: Array[Double], - nodeIndex: Int, - featureIndex: Int): Unit = { - // Update the left or right count for one bin. - val aggIndex = - numClasses * numBins * numFeatures * nodeIndex + - numClasses * numBins * featureIndex + - numClasses * treePoint.features(featureIndex) + - treePoint.label.toInt - agg(aggIndex) += 1 - } - - /** - * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label), - * where [bins] ranges over all bins. - * Updates left or right side of aggregate depending on split. - * - * @param treePoint Data point being aggregated. - * @param agg Indexed by (left/right, node, feature, bin, label) - * where label is the least significant bit. - * The left/right specifier is a 0/1 index indicating left/right child info. - * @param rightChildShift Offset for right side of agg. - * TODO: UPDATE DOC - * TODO: Make arg order same as for ordered feature. - */ - def updateBinForUnorderedFeature( - nodeIndex: Int, - featureIndex: Int, - treePoint: TreePoint, - agg: Array[Double], - rightChildShift: Int): Unit = { - //println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.") - val featureValue = treePoint.features(featureIndex) - // Update the left or right count for one bin. 
- val aggShift = - numClasses * numBins * numFeatures * nodeIndex + - numClasses * numBins * featureIndex + - treePoint.label.toInt - // Find all matching bins and increment their values - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 - var binIndex = 0 - while (binIndex < numCategoricalBins) { - val aggIndex = aggShift + binIndex * numClasses - if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { - agg(aggIndex) += 1 - } else { - agg(rightChildShift + aggIndex) += 1 - } - binIndex += 1 - } - } - - /** - * Helper for binSeqOp. - * - * @param agg Array storing aggregate calculation, of size: - * numClasses * numBins * numFeatures * numNodes. - * Indexed by (node, feature, bin, label) where label is the least significant bit. - * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - */ - def binaryOrNotCategoricalBinSeqOp( - agg: Array[Double], - treePoint: TreePoint, - nodeIndex: Int): Unit = { - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) - featureIndex += 1 - } - } - val rightChildShift = numClasses * numBins * numFeatures * numNodes /** @@ -794,23 +799,40 @@ object DecisionTree extends Serializable with Logging { * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ def multiclassWithCategoricalBinSeqOp( - agg: Array[Double], + agg: Array[Array[Array[ImpurityAggregator]]], treePoint: TreePoint, nodeIndex: Int): Unit = { val label = treePoint.label // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - if (unorderedFeatures.contains(featureIndex)) { - updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift) + if (metadata.isUnordered(featureIndex)) { + // Unordered feature + val featureValue = treePoint.features(featureIndex) + // Update the left or right count for one bin. + // Find all matching bins and increment their values. + val numCategoricalBins = metadata.numBins(featureIndex) + var binIndex = 0 + while (binIndex < numCategoricalBins) { + if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { + agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label) + } else { + agg(nodeIndex)(featureIndex)(numCategoricalBins + binIndex).add(treePoint.label) + } + binIndex += 1 + } } else { - updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) + // Ordered feature + val binIndex = treePoint.features(featureIndex) + agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label) } featureIndex += 1 } } /** + * Helper for binSeqOp: for regression and for classification with only ordered features. + * * Performs a sequential aggregation over a partition for regression. * For l nodes, k features, * the count, sum, sum of squares of one of the p bins is incremented. @@ -821,24 +843,21 @@ object DecisionTree extends Serializable with Logging { * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). * @return agg */ - def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = { - // TODO: Move stuff outside loop. 
+    def orderedBinSeqOp(
+        agg: Array[Array[Array[ImpurityAggregator]]],
+        treePoint: TreePoint,
+        nodeIndex: Int): Unit = {
       val label = treePoint.label
       // Iterate over all features.
       var featureIndex = 0
       while (featureIndex < numFeatures) {
         // Update count, sum, and sum^2 for one bin.
         val binIndex = treePoint.features(featureIndex)
-        val aggIndex =
-          3 * numBins * numFeatures * nodeIndex +
-            3 * numBins * featureIndex +
-            3 * binIndex
-        if (aggIndex >= agg.size) {
-          println(s"aggIndex = $aggIndex, agg.size = ${agg.size}. binIndex = $binIndex, featureIndex = $featureIndex, nodeIndex = $nodeIndex, numBins = $numBins, numFeatures = $numFeatures")
+        if (binIndex >= agg(nodeIndex)(featureIndex).size) {
+          throw new RuntimeException(
+            s"binIndex: $binIndex, agg(nodeIndex)(featureIndex).size = ${agg(nodeIndex)(featureIndex).size}")
         }
-        agg(aggIndex) = agg(aggIndex) + 1
-        agg(aggIndex + 1) = agg(aggIndex + 1) + label
-        agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
+        agg(nodeIndex)(featureIndex)(binIndex).add(label)
         featureIndex += 1
       }
     }
@@ -854,527 +873,358 @@
      *
      * @param agg Array storing aggregate calculation, updated by this function.
      *            Size for classification:
-     *            numClasses * numBins * numFeatures * numNodes for ordered features, or
-     *            2 * numClasses * numBins * numFeatures * numNodes for unordered features.
+     *            Ordered features: numNodes * numFeatures * numBins.
+     *            Unordered features: numNodes * numFeatures * (2 * numBins).
      *            Size for regression:
-     *            3 * numBins * numFeatures * numNodes.
+     *            numNodes * numFeatures * numBins.
      * @param treePoint Data point being aggregated.
      * @return agg
      */
-    def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = {
+    def binSeqOp(
+        agg: Array[Array[Array[ImpurityAggregator]]],
+        treePoint: TreePoint): Array[Array[Array[ImpurityAggregator]]] = {
       val nodeIndex = treePointToNodeIndex(treePoint)
       if (nodeIndex >= 0) {
         // Otherwise, example does not reach this level.
-        strategy.algo match {
-          case Classification =>
-            if (isMulticlassClassificationWithCategoricalFeatures) {
-              multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex)
-            } else {
-              binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex)
-            }
-          case Regression => regressionBinSeqOp(agg, treePoint, nodeIndex)
+        if (isMulticlassWithCategoricalFeatures) {
+          multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex)
+        } else {
+          orderedBinSeqOp(agg, treePoint, nodeIndex)
         }
       }
       agg
     }
 
-    // Calculate bin aggregate length for classification or regression.
-    val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses,
-      isMulticlassClassificationWithCategoricalFeatures, strategy.algo)
-    logDebug("binAggregateLength = " + binAggregateLength)
-
     /**
      * Combines the aggregates from partitions.
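      * (Editor's note: the merge is element-wise over (node, feature, bin); agg2 is
      * folded into agg1 in place and agg1 is returned, so no new array is allocated.)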
* @param agg1 Array containing aggregates from one or more partitions * @param agg2 Array containing aggregates from one or more partitions * @return Combined aggregate from agg1 and agg2 */ - def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { - var index = 0 - val combinedAggregate = new Array[Double](binAggregateLength) - while (index < binAggregateLength) { - combinedAggregate(index) = agg1(index) + agg2(index) - index += 1 + def binCombOp( + agg1: Array[Array[Array[ImpurityAggregator]]], + agg2: Array[Array[Array[ImpurityAggregator]]]): Array[Array[Array[ImpurityAggregator]]] = { + var n = 0 + while (n < agg2.size) { + var f = 0 + while (f < agg2(n).size) { + var b = 0 + while (b < agg2(n)(f).size) { + agg1(n)(f)(b).merge(agg2(n)(f)(b)) + b += 1 + } + f += 1 + } + n += 1 } - combinedAggregate + agg1 } timer.reset() - // Calculate bin aggregates. val binAggregates = { - input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) + val initAgg = getEmptyBinAggregates(metadata, numNodes) + input.aggregate(initAgg)(binSeqOp, binCombOp) } - logDebug("binAggregates.length = " + binAggregates.length) - timer.binAggregatesTime += timer.elapsed() - //2 * numClasses * numBins * numFeatures * numNodes for unordered features. - // (left/right, node, feature, bin, label) /* - println(s"binAggregates:") - for (i <- Range(0,2)) { - for (n <- Range(0,numNodes)) { - for (f <- Range(0,numFeatures)) { - for (b <- Range(0,4)) { - for (c <- Range(0,numClasses)) { - val idx = i * numClasses * numBins * numFeatures * numNodes + - n * numClasses * numBins * numFeatures + - f * numBins * numFeatures + - b * numFeatures + - c - if (binAggregates(idx) != 0) { - println(s"\t ($i, c:$c, b:$b, f:$f, n:$n): ${binAggregates(idx)}") - } - } - } + println("binAggregates:") + for (n <- Range(0, binAggregates.size)) { + for (f <- Range(0, binAggregates(n).size)) { + for (b <- Range(0, binAggregates(n)(f).size)) { + println(s" ($n, $f, $b): ${binAggregates(n)(f)(b)}") } } } */ - /** - * Calculate the information gain for a given (feature, split) based upon left/right aggregates. - * @param leftNodeAgg left node aggregates for this (feature, split) - * @param rightNodeAgg right node aggregate for this (feature, split) - * @param topImpurity impurity of the parent node - * @return information gain and statistics for all splits - */ - def calculateGainForSplit( - leftNodeAgg: Array[Double], - rightNodeAgg: Array[Double], - topImpurity: Double): InformationGainStats = { - strategy.algo match { - case Classification => - val leftTotalCount = leftNodeAgg.sum - val rightTotalCount = rightNodeAgg.sum - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val rootNodeCounts = new Array[Double](numClasses) - var classIndex = 0 - while (classIndex < numClasses) { - rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex) - classIndex += 1 - } - strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) - } - } - - val totalCount = leftTotalCount + rightTotalCount - if (totalCount == 0) { - // Return arbitrary prediction. - //println(s"BLAH: feature $featureIndex, split $splitIndex. 
totalCount == 0") - return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) - } + timer.binAggregatesTime += timer.elapsed() - // Sum of count for each label - val leftrightNodeAgg: Array[Double] = - leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) => - leftCount + rightCount - } + timer.reset() - def indexOfLargestArrayElement(array: Array[Double]): Int = { - val result = array.foldLeft(-1, Double.MinValue, 0) { - case ((maxIndex, maxValue, currentIndex), currentValue) => - if (currentValue > maxValue) { - (currentIndex, currentValue, currentIndex + 1) - } else { - (maxIndex, maxValue, currentIndex + 1) - } - } - if (result._1 < 0) { - throw new RuntimeException("DecisionTree internal error:" + - " calculateGainForSplit failed in indexOfLargestArrayElement") - } - result._1 - } + // Calculate best splits for all nodes at a given level + val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + val nodeIndexOffset = DecisionTree.maxNodesInLevel(level) - 1 + // Iterating over all nodes at this level + var nodeIndex = 0 + while (nodeIndex < numNodes) { + //println(s" HANDLING node $nodeIndex") + val nodeImpurityIndex = nodeIndexOffset + nodeIndex + groupShift + //val binsForNode: Array[Double] = getBinDataForNode(node) + //logDebug("nodeImpurityIndex = " + nodeImpurityIndex) + val parentNodeImpurity = parentImpurities(nodeImpurityIndex) + logDebug("parent node impurity = " + parentNodeImpurity) - val predict = indexOfLargestArrayElement(leftrightNodeAgg) - val prob = leftrightNodeAgg(predict) / totalCount + val (bestFeatureIndex, bestSplitIndex, bestGain) = + binsToBestSplit(binAggregates(nodeIndex), parentNodeImpurity, level, metadata) + bestSplits(nodeIndex) = (splits(bestFeatureIndex)(bestSplitIndex), bestGain) + logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) + logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) + //println(s"bestSplits(node:$node): ${bestSplits(node)}") - val leftImpurity = if (leftTotalCount == 0) { - topImpurity - } else { - strategy.impurity.calculate(leftNodeAgg, leftTotalCount) - } - val rightImpurity = if (rightTotalCount == 0) { - topImpurity - } else { - strategy.impurity.calculate(rightNodeAgg, rightTotalCount) - } + nodeIndex += 1 + } + timer.chooseSplitsTime += timer.elapsed() - val leftWeight = leftTotalCount / totalCount - val rightWeight = rightTotalCount / totalCount + bestSplits + } - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + /** + * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * @param leftNodeAgg left node aggregates for this (feature, split) + * @param rightNodeAgg right node aggregate for this (feature, split) + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ + def calculateGainForSplit( + leftNodeAgg: ImpurityAggregator, + rightNodeAgg: ImpurityAggregator, + topImpurity: Double, + level: Int, + metadata: LearningMetadata): InformationGainStats = { - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + val leftCount = leftNodeAgg.count + val rightCount = rightNodeAgg.count - case Regression => - val leftCount = leftNodeAgg(0) - val leftSum = leftNodeAgg(1) - val leftSumSquares = leftNodeAgg(2) + val totalCount = leftCount + rightCount + if (totalCount == 0) { + // Return arbitrary prediction. + //println(s"BLAH: feature $featureIndex, split $splitIndex. 
totalCount == 0") + return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } - val rightCount = rightNodeAgg(0) - val rightSum = rightNodeAgg(1) - val rightSumSquares = rightNodeAgg(2) + val parentNodeAgg = leftNodeAgg.copy + parentNodeAgg.merge(rightNodeAgg) + // impurity of parent node + val impurity = if (level > 0) { + topImpurity + } else { + parentNodeAgg.calculate() + } - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val count = leftCount + rightCount - val sum = leftSum + rightSum - val sumSquares = leftSumSquares + rightSumSquares - strategy.impurity.calculate(count, sum, sumSquares) - } - } + val predict = parentNodeAgg.predict + val prob = parentNodeAgg.prob(predict) - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum / rightCount) - } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity ,topImpurity, - Double.MinValue, leftSum / leftCount) - } + val leftImpurity = leftNodeAgg.calculate() // Note: 0 if count = 0 + val rightImpurity = rightNodeAgg.calculate() + /* + println(s"calculateGainForSplit") + println(s"\t leftImpurity = $leftImpurity, leftNodeAgg: $leftNodeAgg") + println(s"\t rightImpurity = $rightImpurity, rightNodeAgg: $rightNodeAgg") + */ - val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) - val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) + val leftWeight = leftCount / totalCount.toDouble + val rightWeight = rightCount / totalCount.toDouble - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + } - val predict = (leftSum + rightSum) / (leftCount + rightCount) - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + /** + * Calculates information gain for all nodes splits. + * @param leftNodeAgg Aggregate stats, of dimensions (numFeatures, numSplits(feature)) + * @param rightNodeAgg Aggregate stats, of dimensions (numFeatures, numSplits(feature)) + * @param nodeImpurity Impurity for node being split. + * @return Info gain, of dimensions (numFeatures, numSplits(feature)) + */ + def calculateGainsForAllNodeSplits( + leftNodeAgg: Array[Array[ImpurityAggregator]], + rightNodeAgg: Array[Array[ImpurityAggregator]], + nodeImpurity: Double, + level: Int, + metadata: LearningMetadata): Array[Array[InformationGainStats]] = { + val gains = new Array[Array[InformationGainStats]](metadata.numFeatures) + + for (featureIndex <- 0 until metadata.numFeatures) { + val numSplitsForFeature = metadata.numSplits(featureIndex) + gains(featureIndex) = new Array[InformationGainStats](numSplitsForFeature) + for (splitIndex <- 0 until numSplitsForFeature) { + gains(featureIndex)(splitIndex) = + calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex), + rightNodeAgg(featureIndex)(splitIndex), nodeImpurity, level, metadata) } } + gains + } + + /** + * Extracts left and right split aggregates. + * @param binData Aggregate array slice from getBinDataForNode. 
+ * For classification: + * For unordered features, this is leftChildData ++ rightChildData, + * each of which is indexed by (feature, split/bin, class), + * with class being the least significant bit. + * For ordered features, this is of size numClasses * numBins * numFeatures. + * For regression: + * This is of size 2 * numFeatures * numBins. + * @return (leftNodeAgg, rightNodeAgg) pair of arrays. + * Each array is of size (numFeatures, numSplits(feature)). + * TODO: Extract in-place. + */ + def extractLeftRightNodeAggregates( + nodeAggregates: Array[Array[ImpurityAggregator]], + metadata: LearningMetadata): (Array[Array[ImpurityAggregator]], Array[Array[ImpurityAggregator]]) = { + + val numClasses = metadata.numClasses + val numFeatures = metadata.numFeatures /** - * Extracts left and right split aggregates. - * @param binData Aggregate array slice from getBinDataForNode. - * For classification: - * For unordered features, this is leftChildData ++ rightChildData, - * each of which is indexed by (feature, split/bin, class), - * with class being the least significant bit. - * For ordered features, this is of size numClasses * numBins * numFeatures. - * For regression: - * This is of size 2 * numFeatures * numBins. - * @return (leftNodeAgg, rightNodeAgg) pair of arrays. - * For classification, each array is of size (numFeatures, (numBins - 1), numClasses). - * For regression, each array is of size (numFeatures, (numBins - 1), 3). - * + * Reshape binData for this feature. + * Indexes binData as (feature, split, class) with class as the least significant bit. + * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value */ - def extractLeftRightNodeAggregates( - binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { - - - /** - * The input binData is indexed as (feature, bin, class). - * This computes cumulative sums over splits. - * Each (feature, class) pair is handled separately. - * Note: numSplits = numBins - 1. - * @param leftNodeAgg Each (feature, class) slice is an array over splits. - * Element i (i = 0, ..., numSplits - 2) is set to be - * the cumulative sum (from left) over binData for bins 0, ..., i. - * @param rightNodeAgg Each (feature, class) slice is an array over splits. - * Element i (i = 1, ..., numSplits - 1) is set to be - * the cumulative sum (from right) over binData for bins - * numBins - 1, ..., numBins - 1 - i. - * TODO: We could avoid doing one of these cumulative sums. - */ - def findAggForOrderedFeatureClassification( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { - - // shift for this featureIndex - val shift = numClasses * featureIndex * numBins - - var classIndex = 0 - while (classIndex < numClasses) { - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex) - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(classIndex) - = binData(shift + (numClasses * (numBins - 1)) + classIndex) - classIndex += 1 - } + def findAggForUnorderedFeature( + binData: Array[Array[ImpurityAggregator]], + leftNodeAgg: Array[Array[ImpurityAggregator]], + rightNodeAgg: Array[Array[ImpurityAggregator]], + featureIndex: Int) { + // TODO: Don't pass in featureIndex; use index before call. + // Note: numBins = numSplits for unordered features. 
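+      // Editor's illustration: for arity 3, numBins = numSplits = 2^2 - 1 = 3, so
+      // binData(featureIndex) holds 6 aggregators; slice(0, 3) is the left-child half
+      // and slice(3, 6) the right-child half, matching the layout written by binSeqOp.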
+ val numBins = metadata.numBins(featureIndex) + leftNodeAgg(featureIndex) = binData(featureIndex).slice(0, numBins) + rightNodeAgg(featureIndex) = binData(featureIndex).slice(numBins, 2 * numBins) + } - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - var innerClassIndex = 0 - while (innerClassIndex < numClasses) { - leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) - = binData(shift + numClasses * splitIndex + innerClassIndex) + - leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = - binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) - innerClassIndex += 1 - } - splitIndex += 1 - } - } + /** + * For ordered features (regression and classification with ordered features). + * The input binData is indexed as (feature, bin, class). + * This computes cumulative sums over splits. + * Each (feature, class) pair is handled separately. + * Note: numSplits = numBins - 1. + * @param leftNodeAgg Each (feature, class) slice is an array over splits. + * Element i (i = 0, ..., numSplits - 2) is set to be + * the cumulative sum (from left) over binData for bins 0, ..., i. + * @param rightNodeAgg Each (feature, class) slice is an array over splits. + * Element i (i = 1, ..., numSplits - 1) is set to be + * the cumulative sum (from right) over binData for bins + * numBins - 1, ..., numBins - 1 - i. + * TODO: We could avoid doing one of these cumulative sums. + */ + def findAggForOrderedFeature( + binData: Array[Array[ImpurityAggregator]], + leftNodeAgg: Array[Array[ImpurityAggregator]], + rightNodeAgg: Array[Array[ImpurityAggregator]], + featureIndex: Int) { - /** - * Reshape binData for this feature. - * Indexes binData as (feature, split, class) with class as the least significant bit. - * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value - */ - def findAggForUnorderedFeatureClassification( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { - - val rightChildShift = numClasses * numBins * numFeatures - var splitIndex = 0 - while (splitIndex < numBins - 1) { - var classIndex = 0 - while (classIndex < numClasses) { - // shift for this featureIndex - val shift = numClasses * featureIndex * numBins + splitIndex * numClasses - val leftBinValue = binData(shift + classIndex) - val rightBinValue = binData(rightChildShift + shift + classIndex) - leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue - rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue - classIndex += 1 - } - splitIndex += 1 - } - } + // TODO: Don't pass in featureIndex; use index before call. 
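+      // Editor's illustration: for a continuous feature with 4 bins b0..b3 (3 splits),
+      // the cumulative sums below yield
+      //   leftNodeAgg  = [b0, b0+b1, b0+b1+b2]
+      //   rightNodeAgg = [b1+b2+b3, b2+b3, b3],
+      // so split i separates bins 0..i (left) from bins i+1..3 (right).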
- def findAggForRegression( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { + val numSplits = metadata.numSplits(featureIndex) + leftNodeAgg(featureIndex) = new Array[ImpurityAggregator](numSplits) + rightNodeAgg(featureIndex) = new Array[ImpurityAggregator](numSplits) - // shift for this featureIndex - val shift = 3 * featureIndex * numBins + if (metadata.isContinuous(featureIndex)) { // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) - leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) - + leftNodeAgg(featureIndex)(0) = binData(featureIndex)(0).copy // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(0) = - binData(shift + (3 * (numBins - 1))) - rightNodeAgg(featureIndex)(numBins - 2)(1) = - binData(shift + (3 * (numBins - 1)) + 1) - rightNodeAgg(featureIndex)(numBins - 2)(2) = - binData(shift + (3 * (numBins - 1)) + 2) + rightNodeAgg(featureIndex)(numSplits - 1) = binData(featureIndex)(numSplits).copy // Iterate over all splits. var splitIndex = 1 - while (splitIndex < numBins - 1) { - var i = 0 // index for regression histograms - while (i < 3) { // count, sum, sum^2 - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) + - leftNodeAgg(featureIndex)(splitIndex - 1)(i) - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) = - binData(shift + (3 * (numBins - 1 - splitIndex) + i)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i) - i += 1 - } + while (splitIndex < numSplits) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(splitIndex) = leftNodeAgg(featureIndex)(splitIndex - 1).copy + leftNodeAgg(featureIndex)(splitIndex).merge(binData(featureIndex)(splitIndex)) + rightNodeAgg(featureIndex)(numSplits - 1 - splitIndex) = + rightNodeAgg(featureIndex)(numSplits - splitIndex).copy + rightNodeAgg(featureIndex)(numSplits - 1 - splitIndex).merge( + binData(featureIndex)(numSplits - splitIndex)) splitIndex += 1 } - } - - strategy.algo match { - case Classification => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (unorderedFeatures.contains(featureIndex)) { - findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - featureIndex += 1 - } - - (leftNodeAgg, rightNodeAgg) - case Regression => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - // Iterate over all features. 
- var featureIndex = 0 - while (featureIndex < numFeatures) { - findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) - featureIndex += 1 - } - (leftNodeAgg, rightNodeAgg) - } - } - - /** - * Calculates information gain for all nodes splits. - */ - def calculateGainsForAllNodeSplits( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - nodeImpurity: Double): Array[Array[InformationGainStats]] = { - val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) - - for (featureIndex <- 0 until numFeatures) { - val numSplitsForFeature = getNumSplitsForFeature(featureIndex) - for (splitIndex <- 0 until numSplitsForFeature) { - gains(featureIndex)(splitIndex) = - calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex), - rightNodeAgg(featureIndex)(splitIndex), nodeImpurity) + } else { // ordered categorical feature + /* TODO: This is a temp fix. + * Eventually, for ordered categorical features, change splits and bins to be + * for individual categories instead of running totals over a pre-defined category + * ordering. Then, we could choose the ordering in this function, tailoring it + * to this particular node. + */ + var splitIndex = 0 + while (splitIndex < numSplits) { + // no need to clone since no cumulative sum is needed + leftNodeAgg(featureIndex)(splitIndex) = binData(featureIndex)(splitIndex) + rightNodeAgg(featureIndex)(splitIndex) = binData(featureIndex)(splitIndex + 1) + splitIndex += 1 } } - gains } - /** - * Get the number of splits for a feature. - */ - def getNumSplitsForFeature(featureIndex: Int): Int = { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - numBins - 1 - } else { - // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - if (unorderedFeatures.contains(featureIndex)) { - math.pow(2.0, featureCategories - 1).toInt - 1 - } else { - featureCategories - } - /* - val isSpaceSufficientForAllCategoricalSplits = - numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { - math.pow(2.0, featureCategories - 1).toInt - 1 + val leftNodeAgg = new Array[Array[ImpurityAggregator]](numFeatures) + val rightNodeAgg = new Array[Array[ImpurityAggregator]](numFeatures) + if (metadata.isClassification) { + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (metadata.isUnordered(featureIndex)) { + findAggForUnorderedFeature(nodeAggregates, leftNodeAgg, rightNodeAgg, featureIndex) } else { - // Ordered features - featureCategories + findAggForOrderedFeature(nodeAggregates, leftNodeAgg, rightNodeAgg, featureIndex) } - */ + featureIndex += 1 + } + } else { // Regression + var featureIndex = 0 + while (featureIndex < numFeatures) { + findAggForOrderedFeature(nodeAggregates, leftNodeAgg, rightNodeAgg, featureIndex) + featureIndex += 1 } } + (leftNodeAgg, rightNodeAgg) + } - /** - * Find the best split for a node. - * @param binData Bin data slice for this node, given by getBinDataForNode. - * @param nodeImpurity impurity of the top node - * @return tuple of split and information gain - */ - def binsToBestSplit( - binData: Array[Double], - nodeImpurity: Double): (Split, InformationGainStats) = { - - logDebug("node impurity = " + nodeImpurity) - - // Extract left right node aggregates. - val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) - - // Calculate gains for all splits. 
- val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) + /** + * Find the best split for a node. + * @param binData Bin data slice for this node, given by getBinDataForNode. + * @param nodeImpurity impurity of the top node + * @return tuple (best feature index, best split index, information gain) + */ + def binsToBestSplit( + nodeAggregates: Array[Array[ImpurityAggregator]], + nodeImpurity: Double, + level: Int, + metadata: LearningMetadata): (Int, Int, InformationGainStats) = { - val (bestFeatureIndex, bestSplitIndex, gainStats) = { - // Initialize with infeasible values. - var bestFeatureIndex = Int.MinValue - var bestSplitIndex = Int.MinValue - var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) - // Iterate over features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Iterate over all splits. - var splitIndex = 0 - val numSplitsForFeature = getNumSplitsForFeature(featureIndex) - while (splitIndex < numSplitsForFeature) { - val gainStats = gains(featureIndex)(splitIndex) - if (gainStats.gain > bestGainStats.gain) { - bestGainStats = gainStats - bestFeatureIndex = featureIndex - bestSplitIndex = splitIndex - //println(s" feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats") - } - splitIndex += 1 - } - featureIndex += 1 - } - (bestFeatureIndex, bestSplitIndex, bestGainStats) + logDebug("node impurity = " + nodeImpurity) + /* + println("nodeAggregates") + for (f <- Range(0, nodeAggregates.size)) { + for (b <- Range(0, nodeAggregates(f).size)) { + println(s"nodeAggregates($f)($b): ${nodeAggregates(f)(b)}") } - - logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) - logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) - - (splits(bestFeatureIndex)(bestSplitIndex), gainStats) } - - /** - * Get bin data for one node. - */ - def getBinDataForNode(node: Int): Array[Double] = { - strategy.algo match { - case Classification => - if (isMulticlassClassificationWithCategoricalFeatures) { - val shift = numClasses * node * numBins * numFeatures - val rightChildShift = numClasses * numBins * numFeatures * numNodes - val binsForNode = { - val leftChildData - = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - val rightChildData - = binAggregates.slice(rightChildShift + shift, - rightChildShift + shift + numClasses * numBins * numFeatures) - leftChildData ++ rightChildData - } - binsForNode - } else { - val shift = numClasses * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - binsForNode + */ + // Extract left right node aggregates. + val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(nodeAggregates, metadata) + + // Calculate gains for all splits. + val gains = + calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity, level, metadata) + + val (bestFeatureIndex, bestSplitIndex, gainStats) = { + // Initialize with infeasible values. + var bestFeatureIndex = Int.MinValue + var bestSplitIndex = Int.MinValue + var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) + // Iterate over features. + var featureIndex = 0 + while (featureIndex < metadata.numFeatures) { + // Iterate over all splits. 
+ var splitIndex = 0 + val numSplitsForFeature = metadata.numSplits(featureIndex) + while (splitIndex < numSplitsForFeature) { + val gainStats = gains(featureIndex)(splitIndex) + if (gainStats.gain > bestGainStats.gain) { + bestGainStats = gainStats + bestFeatureIndex = featureIndex + bestSplitIndex = splitIndex + //println(s" feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats") } - case Regression => - val shift = 3 * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) - binsForNode + splitIndex += 1 + } + featureIndex += 1 } + (bestFeatureIndex, bestSplitIndex, bestGainStats) } - timer.reset() - - // Calculate best splits for all nodes at a given level - val bestSplits = new Array[(Split, InformationGainStats)](numNodes) - // Iterating over all nodes at this level - var node = 0 - while (node < numNodes) { - val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift - val binsForNode: Array[Double] = getBinDataForNode(node) - logDebug("nodeImpurityIndex = " + nodeImpurityIndex) - val parentNodeImpurity = parentImpurities(nodeImpurityIndex) - logDebug("parent node impurity = " + parentNodeImpurity) - bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) - //println(s"bestSplits(node:$node): ${bestSplits(node)}") - node += 1 - } - timer.chooseSplitsTime += timer.elapsed() - - bestSplits + (bestFeatureIndex, bestSplitIndex, gainStats) } /** @@ -1399,6 +1249,47 @@ object DecisionTree extends Serializable with Logging { } } + /** + * Get an empty instance of bin aggregates. + * For ordered features, aggregate is indexed by: (nodeIndex)(featureIndex)(binIndex). + * For unordered features, aggregate is indexed by: (nodeIndex)(featureIndex)(2 * binIndex), + * where the bins are ordered as (numBins left bins, numBins right bins). + */ + private def getEmptyBinAggregates( + metadata: LearningMetadata, + numNodes: Int): Array[Array[Array[ImpurityAggregator]]] = { + val impurityAggregator: ImpurityAggregator = metadata.impurity match { + case Gini => new GiniAggregator(metadata.numClasses) + case Entropy => new EntropyAggregator(metadata.numClasses) + case Variance => new VarianceAggregator() + case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") + } + + val binMultiplier = if (metadata.isMulticlassWithCategoricalFeatures) { + 2 + } else { + 1 + } + val agg = Array.fill[Array[ImpurityAggregator]](numNodes, metadata.numFeatures)( + new Array[ImpurityAggregator](0)) + var nodeIndex = 0 + while (nodeIndex < numNodes) { + var featureIndex = 0 + while (featureIndex < metadata.numFeatures) { + var binIndex = 0 + val effNumBins = metadata.numBins(featureIndex) * binMultiplier + agg(nodeIndex)(featureIndex) = new Array[ImpurityAggregator](effNumBins) + while (binIndex < effNumBins) { + agg(nodeIndex)(featureIndex)(binIndex) = impurityAggregator.newAggregator + binIndex += 1 + } + featureIndex += 1 + } + nodeIndex += 1 + } + agg + } + /** * Returns splits and bins for decision tree calculation. * Continuous and categorical features are handled differently. @@ -1423,49 +1314,28 @@ object DecisionTree extends Serializable with Logging { * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for construction the DecisionTree - * @return A tuple of (splits, bins, unorderedFeatures). 
+ * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] - * of size (numFeatures, numBins - 1). + * of size (numFeatures, numSplits). * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] * of size (numFeatures, numBins). - * unorderedFeatures: set of indices for unordered features. */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], - strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]], Set[Int]) = { - - val count = input.count() - - // Find the number of features by looking at the first sample - val numFeatures = input.take(1)(0).features.size + metadata: LearningMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { - val maxBins = strategy.maxBins - val numBins = if (maxBins <= count) maxBins else count.toInt - logDebug("numBins = " + numBins) - val isMulticlassClassification = strategy.isMulticlassClassification - logDebug("isMulticlassClassification = " + isMulticlassClassification) - - - /* - * Ensure numBins is always greater than the categories. For multiclass classification, - * numBins should be greater than 2^(maxCategories - 1) - 1. - * It's a limitation of the current implementation but a reasonable trade-off since features - * with large number of categories get favored over continuous features. - * - * This needs to be checked here instead of in Strategy since numBins can be determined - * by the number of training examples. - * TODO: Allow this case, where we simply will know nothing about some categories. - */ - if (strategy.categoricalFeaturesInfo.size > 0) { - val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 - require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + - "in categorical features") - } + val isMulticlassClassification = metadata.isMulticlass + logDebug("isMulticlass = " + isMulticlassClassification) + val numFeatures = metadata.numFeatures // Calculate the number of sample for approximate quantile calculation. - val requiredSamples = numBins*numBins - val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 + val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) + val fraction = if (requiredSamples < metadata.numExamples) { + requiredSamples.toDouble / metadata.numExamples + } else { + 1.0 + } logDebug("fraction of data used for calculating quantiles = " + fraction) // sampled input for RDD calculation @@ -1473,55 +1343,57 @@ object DecisionTree extends Serializable with Logging { input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() val numSamples = sampledInput.length - val stride: Double = numSamples.toDouble / numBins - logDebug("stride = " + stride) - - strategy.quantileCalculationStrategy match { + metadata.quantileStrategy match { case Sort => - val splits = Array.ofDim[Split](numFeatures, numBins - 1) - val bins = Array.ofDim[Bin](numFeatures, numBins) + val splits = new Array[Array[Split]](numFeatures) + val bins = new Array[Array[Bin]](numFeatures) + var i = 0 + while (i < numFeatures) { + splits(i) = new Array[Split](metadata.numSplits(i)) + bins(i) = new Array[Bin](metadata.numBins(i)) + i += 1 + } // Find all splits. - // Record which categorical features will be ordered vs. unordered. - val unorderedFeatures = new mutable.HashSet[Int]() - // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // Check whether the feature is continuous. 
- val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { + val numSplits = metadata.numSplits(featureIndex) + if (metadata.isContinuous(featureIndex)) { val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride: Double = numSamples.toDouble / numBins + val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex) logDebug("stride = " + stride) - for (index <- 0 until numBins - 1) { - val sampleIndex = index * stride.toInt + for (splitIndex <- 0 until numSplits) { + val sampleIndex = splitIndex * stride.toInt // Set threshold halfway in between 2 samples. val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 - val split = new Split(featureIndex, threshold, Continuous, List()) - splits(featureIndex)(index) = split + splits(featureIndex)(splitIndex) = + new Split(featureIndex, threshold, Continuous, List()) } - } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - - // Use different bin/split calculation strategy for categorical features in multiclass - // classification that satisfy the space constraint. - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - if (isUnorderedFeature) { - unorderedFeatures.add(featureIndex) + bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), + splits(featureIndex)(0), Continuous, Double.MinValue) + for (splitIndex <- 1 until numSplits) { + bins(featureIndex)(splitIndex) = + new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex), + Continuous, Double.MinValue) + } + bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1), + new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) + } else { + // Categorical feature + val featureArity = metadata.featureArity(featureIndex) + if (metadata.isUnordered(featureIndex)) { + // Unordered features: low-arity features in multiclass classification // 2^(maxFeatureValue- 1) - 1 combinations - var index = 0 - while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { - val categories: List[Double] - = extractMultiClassCategories(index + 1, featureCategories) - splits(featureIndex)(index) - = new Split(featureIndex, Double.MinValue, Categorical, categories) - bins(featureIndex)(index) = { - if (index == 0) { + var splitIndex = 0 + while (splitIndex < numSplits) { + val categories: List[Double] = + extractMultiClassCategories(splitIndex + 1, featureArity) + splits(featureIndex)(splitIndex) = + new Split(featureIndex, Double.MinValue, Categorical, categories) + bins(featureIndex)(splitIndex) = { + if (splitIndex == 0) { new Bin( new DummyCategoricalSplit(featureIndex, Categorical), splits(featureIndex)(0), @@ -1529,15 +1401,16 @@ object DecisionTree extends Serializable with Logging { Double.MinValue) } else { new Bin( - splits(featureIndex)(index - 1), - splits(featureIndex)(index), + splits(featureIndex)(splitIndex - 1), + splits(featureIndex)(splitIndex), Categorical, Double.MinValue) } } - index += 1 + splitIndex += 1 } - } else { // ordered feature + } else { + // Ordered features: high-arity features, or not multiclass classification /* For a given categorical feature, use a subsample of the data * to choose how to arrange possible splits. * This examines each category and computes a centroid. 
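[Editor's illustration] A standalone sketch of the centroid ordering described above, assuming the binary-classification branch averages labels per category as the surrounding comments indicate; names and sample data are invented:

object CentroidOrderingSketch {
  def main(args: Array[String]): Unit = {
    // (categoryValue, label) pairs from a subsample.
    val samples = Seq((0.0, 1.0), (0.0, 1.0), (1.0, 0.0), (2.0, 1.0), (2.0, 0.0))
    // Centroid of a category = mean label of its samples.
    val centroids: Map[Double, Double] =
      samples.groupBy(_._1).map { case (cat, xs) => (cat, xs.map(_._2).sum / xs.size) }
    // centroids: 0.0 -> 1.0, 1.0 -> 0.0, 2.0 -> 0.5
    val orderedCategories = centroids.toSeq.sortBy(_._2).map(_._1)
    // Ranked 1.0, 2.0, 0.0; candidate splits are the prefixes {1.0} and {1.0, 2.0}.
    println(orderedCategories.mkString(", "))
  }
}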
@@ -1553,7 +1426,7 @@ object DecisionTree extends Serializable with Logging { .groupBy(_._1) .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum))) + .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum))) } else { // regression or binary classification // For categorical variables in regression and binary classification, // each bin is a category. The bins are sorted and they @@ -1568,7 +1441,7 @@ object DecisionTree extends Serializable with Logging { // Check for missing categorical variables and putting them last in the sorted list. val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() - for (i <- 0 until featureCategories) { + for (i <- 0 until featureArity) { if (centroidForCategories.contains(i)) { fullCentroidForCategories(i) = centroidForCategories(i) } else { @@ -1583,17 +1456,23 @@ object DecisionTree extends Serializable with Logging { var categoriesForSplit = List[Double]() categoriesSortedByCentroid.iterator.zipWithIndex.foreach { - case ((key, value), index) => - categoriesForSplit = key :: categoriesForSplit - splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, - Categorical, categoriesForSplit) - bins(featureIndex)(index) = { - if (index == 0) { + case ((category, value), binIndex) => + categoriesForSplit = category :: categoriesForSplit + if (binIndex < numSplits) { + splits(featureIndex)(binIndex) = + new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) + } + bins(featureIndex)(binIndex) = { + if (binIndex == 0) { new Bin(new DummyCategoricalSplit(featureIndex, Categorical), - splits(featureIndex)(0), Categorical, key) + splits(featureIndex)(0), Categorical, category) + } else if (binIndex == numSplits) { + new Bin(splits(featureIndex)(binIndex - 1), + new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit), + Categorical, category) } else { - new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), - Categorical, key) + new Bin(splits(featureIndex)(binIndex - 1), splits(featureIndex)(binIndex), + Categorical, category) } } } @@ -1602,24 +1481,7 @@ object DecisionTree extends Serializable with Logging { featureIndex += 1 } - // Find all bins. - featureIndex = 0 - while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { // Bins for categorical variables are already assigned. 
- bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), - splits(featureIndex)(0), Continuous, Double.MinValue) - for (index <- 1 until numBins - 1) { - val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), - Continuous, Double.MinValue) - bins(featureIndex)(index) = bin - } - bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2), - new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) - } - featureIndex += 1 - } - (splits, bins, unorderedFeatures.toSet) + (splits, bins) case MinMax => throw new UnsupportedOperationException("minmax not supported yet.") case ApproxHist => @@ -1651,4 +1513,12 @@ object DecisionTree extends Serializable with Logging { categories } + private[tree] def maxNodesInLevel(level: Int): Int = { + math.pow(2, level).toInt + } + + private[tree] def numUnorderedBins(arity: Int): Int = { + (math.pow(2, arity - 1) - 1).toInt + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index f3b5dce041207..7ea0071654161 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree.impl import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.LearningMetadata import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.model.Bin import org.apache.spark.rdd.RDD @@ -34,39 +35,24 @@ private[tree] object TreePoint { def convertToTreeRDD( input: RDD[LabeledPoint], - strategy: Strategy, - bins: Array[Array[Bin]]): RDD[TreePoint] = { + bins: Array[Array[Bin]], + metadata: LearningMetadata): RDD[TreePoint] = { input.map { x => - TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins, - strategy.categoricalFeaturesInfo) + TreePoint.labeledPointToTreePoint(x, bins, metadata) } } def labeledPointToTreePoint( labeledPoint: LabeledPoint, - isMulticlassClassification: Boolean, bins: Array[Array[Bin]], - categoricalFeaturesInfo: Map[Int, Int]): TreePoint = { + metadata: LearningMetadata): TreePoint = { val numFeatures = labeledPoint.features.size - val numBins = bins(0).size val arr = new Array[Int](numFeatures) - var featureIndex = 0 // offset by 1 for label + var featureIndex = 0 while (featureIndex < numFeatures) { - val featureInfo = categoricalFeaturesInfo.get(featureIndex) - val isFeatureContinuous = featureInfo.isEmpty - if (isFeatureContinuous) { - arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false, - bins, categoricalFeaturesInfo) - } else { - val featureCategories = featureInfo.get - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, - isUnorderedFeature, bins, categoricalFeaturesInfo) - } + arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex), + metadata.isUnordered(featureIndex), bins, metadata.featureArity) featureIndex += 1 } @@ -172,6 +158,7 @@ private[tree] object TreePoint { sequentialBinSearchForOrderedCategoricalFeature() } if (binIndex == -1) { + println(s"findBin: binIndex = -1. 
isUnorderedFeature = $isUnorderedFeature, featureIndex = $featureIndex, labeledPoint = $labeledPoint") throw new UnknownError("no bin was found for categorical variable.") } binIndex diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 96d2471e1f88c..77d64b69c39c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -74,3 +74,54 @@ object Entropy extends Impurity { def instance = this } + +private[tree] class EntropyAggregator(numClasses: Int) + extends ImpurityAggregator(numClasses) with Serializable { + + def calculate(): Double = { + Entropy.calculate(counts, counts.sum) + } + + def copy: EntropyAggregator = { + val tmp = new EntropyAggregator(counts.size) + tmp.counts = this.counts.clone() + tmp + } + + def add(label: Double): Unit = { + if (label >= counts.size) { + throw new IllegalArgumentException(s"EntropyAggregator given label $label" + + s" but requires label < numClasses (= ${counts.size}).") + } + counts(label.toInt) += 1 + } + + def count: Long = counts.sum.toLong + + def predict: Double = if (count == 0) { + 0 + } else { + indexOfLargestArrayElement(counts) + } + + override def prob(label: Double): Double = { + val lbl = label.toInt + require(lbl < counts.length, + s"EntropyAggregator.prob given invalid label: $lbl (should be < ${counts.length})") + val cnt = count + if (cnt == 0) { + 0 + } else { + counts(lbl) / cnt + } + } + + override def toString: String = { + s"EntropyAggregator(counts = [${counts.mkString(", ")}])" + } + + def newAggregator: EntropyAggregator = { + new EntropyAggregator(counts.size) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index d586f449048bb..e8aa4e9c7f7c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -70,3 +70,57 @@ object Gini extends Impurity { def instance = this } + +private[tree] class GiniAggregator(numClasses: Int) + extends ImpurityAggregator(numClasses) with Serializable { + + def calculate(): Double = { + Gini.calculate(counts, counts.sum) + } + + def copy: GiniAggregator = { + val tmp = new GiniAggregator(counts.size) + tmp.counts = this.counts.clone() + tmp + } + + def add(label: Double): Unit = { + if (label >= counts.size) { + throw new IllegalArgumentException(s"GiniAggregator given label $label" + + s" but requires label < numClasses (= ${counts.size}).") + } + if (label.toInt >= counts.size) { + throw new RuntimeException(s"label = $label, counts = $counts") + } + counts(label.toInt) += 1 + } + + def count: Long = counts.sum.toLong + + def predict: Double = if (count == 0) { + 0 + } else { + indexOfLargestArrayElement(counts) + } + + override def prob(label: Double): Double = { + val lbl = label.toInt + require(lbl < counts.length, + s"GiniAggregator.prob given invalid label: $lbl (should be < ${counts.length})") + val cnt = count + if (cnt == 0) { + 0 + } else { + counts(lbl) / cnt + } + } + + override def toString: String = { + s"GiniAggregator(counts = [${counts.mkString(", ")}])" + } + + def newAggregator: GiniAggregator = { + new GiniAggregator(counts.size) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 92b0c7b4a6fbc..807207d827137 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -47,3 +47,52 @@ trait Impurity extends Serializable { @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double } + + +private[tree] abstract class ImpurityAggregator(statsSize: Int) extends Serializable { + + var counts: Array[Double] = new Array[Double](statsSize) + + def copy: ImpurityAggregator + + def add(label: Double): Unit + + def calculate(): Double + + def merge(other: ImpurityAggregator): ImpurityAggregator = { + require(counts.size == other.counts.size, + s"Two ImpurityAggregator instances cannot be merged with different counts sizes." + + s" Sizes are ${counts.size} and ${other.counts.size}.") + var i = 0 + while (i < other.counts.size) { + counts(i) += other.counts(i) + i += 1 + } + this + } + + def count: Long + + def newAggregator: ImpurityAggregator + + def predict: Double + + def prob(label: Double): Double = -1 + + protected def indexOfLargestArrayElement(array: Array[Double]): Int = { + val result = array.foldLeft(-1, Double.MinValue, 0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if (currentValue > maxValue) { + (currentIndex, currentValue, currentIndex + 1) + } else { + (maxIndex, maxValue, currentIndex + 1) + } + } + if (result._1 < 0) { + throw new RuntimeException("ImpurityAggregator internal error:" + + " indexOfLargestArrayElement failed") + } + result._1 + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index f7d99a40eb380..c406cd580fc40 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -61,3 +61,39 @@ object Variance extends Impurity { def instance = this } + +private[tree] class VarianceAggregator extends ImpurityAggregator(3) with Serializable { + + def calculate(): Double = { + Variance.calculate(counts(0), counts(1), counts(2)) + } + + def copy: VarianceAggregator = { + val tmp = new VarianceAggregator() + tmp.counts = this.counts.clone() + tmp + } + + def add(label: Double): Unit = { + counts(0) += label + counts(1) += label * label + counts(2) += 1 + } + + def count: Long = counts(2).toLong + + def predict: Double = if (count == 0) { + 0 + } else { + counts(0) / counts(2) + } + + override def toString: String = { + s"VarianceAggregator(sum = ${counts(0)}, sum2 = ${counts(1)}, cnt = ${counts(2)})" + } + + def newAggregator: VarianceAggregator = { + new VarianceAggregator() + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 9e6429f2ff108..fad840331063b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -62,19 +62,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(mse <= requiredMSE) } - test("split and bin calculation") { + test("split and bin calculation for continuous features") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new 
Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) assert(bins(0).length === 100) } - test("split and bin calculation for categorical variables") { + test("split and bin calculation for binary features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -85,11 +86,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) + + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) assert(splits.length === 2) assert(bins.length === 2) - assert(splits(0).length === 99) - assert(bins(0).length === 100) + assert(splits(0).length === 1) + assert(bins(0).length === 2) // Check splits. @@ -97,16 +102,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0)(0).threshold === Double.MinValue) assert(splits(0)(0).featureType === Categorical) assert(splits(0)(0).categories.length === 1) + //println(s"splits(0)(0).categories: ${splits(0)(0).categories}") assert(splits(0)(0).categories.contains(1.0)) + /* assert(splits(0)(1).feature === 0) assert(splits(0)(1).threshold === Double.MinValue) assert(splits(0)(1).featureType === Categorical) assert(splits(0)(1).categories.length === 2) - assert(splits(0)(1).categories.contains(1.0)) assert(splits(0)(1).categories.contains(0.0)) - - assert(splits(0)(2) === null) + assert(splits(0)(1).categories.contains(1.0)) + */ assert(splits(1)(0).feature === 1) assert(splits(1)(0).threshold === Double.MinValue) @@ -114,47 +120,38 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(1)(0).categories.length === 1) assert(splits(1)(0).categories.contains(0.0)) + /* assert(splits(1)(1).feature === 1) assert(splits(1)(1).threshold === Double.MinValue) assert(splits(1)(1).featureType === Categorical) assert(splits(1)(1).categories.length === 2) - assert(splits(1)(1).categories.contains(1.0)) assert(splits(1)(1).categories.contains(0.0)) - - assert(splits(1)(2) === null) - + assert(splits(1)(1).categories.contains(1.0)) + */ // Check bins. 
- assert(bins(0)(0).category === 1.0) assert(bins(0)(0).lowSplit.categories.length === 0) assert(bins(0)(0).highSplit.categories.length === 1) assert(bins(0)(0).highSplit.categories.contains(1.0)) - assert(bins(0)(1).category === 0.0) assert(bins(0)(1).lowSplit.categories.length === 1) assert(bins(0)(1).lowSplit.categories.contains(1.0)) assert(bins(0)(1).highSplit.categories.length === 2) - assert(bins(0)(1).highSplit.categories.contains(1.0)) assert(bins(0)(1).highSplit.categories.contains(0.0)) + assert(bins(0)(1).highSplit.categories.contains(1.0)) - assert(bins(0)(2) === null) - - assert(bins(1)(0).category === 0.0) assert(bins(1)(0).lowSplit.categories.length === 0) assert(bins(1)(0).highSplit.categories.length === 1) assert(bins(1)(0).highSplit.categories.contains(0.0)) - assert(bins(1)(1).category === 1.0) assert(bins(1)(1).lowSplit.categories.length === 1) assert(bins(1)(1).lowSplit.categories.contains(0.0)) assert(bins(1)(1).highSplit.categories.length === 2) assert(bins(1)(1).highSplit.categories.contains(0.0)) assert(bins(1)(1).highSplit.categories.contains(1.0)) - - assert(bins(1)(2) === null) } - test("split and bin calculations for categorical variables with no sample for one category") { + test("Binary classification with 3-category features, with no samples for one category") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -165,7 +162,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) + + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 2) + assert(bins(0).length === 3) // Check splits. 
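For reference on the counts asserted in these tests: an ordered categorical feature with k categories yields k bins and k - 1 splits, while a low-arity unordered feature yields 2^(k-1) - 1 splits, one per nontrivial binary partition of the category set, matching the numUnorderedBins helper added in this patch. The check below is an illustrative sketch, not part of the patch:

// Illustrative sanity check of the bin/split counts used in these tests.
def numOrderedSplits(arity: Int): Int = arity - 1
def numUnorderedSplits(arity: Int): Int = (math.pow(2, arity - 1) - 1).toInt
assert(numOrderedSplits(3) == 2)    // 3-category ordered feature: 2 splits, 3 bins
assert(numUnorderedSplits(3) == 3)  // 3-category unordered feature: 2^2 - 1 = 3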
@@ -179,18 +184,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0)(1).threshold === Double.MinValue) assert(splits(0)(1).featureType === Categorical) assert(splits(0)(1).categories.length === 2) - assert(splits(0)(1).categories.contains(1.0)) assert(splits(0)(1).categories.contains(0.0)) + assert(splits(0)(1).categories.contains(1.0)) + /* assert(splits(0)(2).feature === 0) assert(splits(0)(2).threshold === Double.MinValue) assert(splits(0)(2).featureType === Categorical) assert(splits(0)(2).categories.length === 3) - assert(splits(0)(2).categories.contains(1.0)) assert(splits(0)(2).categories.contains(0.0)) + assert(splits(0)(2).categories.contains(1.0)) assert(splits(0)(2).categories.contains(2.0)) - - assert(splits(0)(3) === null) + */ assert(splits(1)(0).feature === 1) assert(splits(1)(0).threshold === Double.MinValue) @@ -202,66 +207,53 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(1)(1).threshold === Double.MinValue) assert(splits(1)(1).featureType === Categorical) assert(splits(1)(1).categories.length === 2) - assert(splits(1)(1).categories.contains(1.0)) assert(splits(1)(1).categories.contains(0.0)) + assert(splits(1)(1).categories.contains(1.0)) + /* assert(splits(1)(2).feature === 1) assert(splits(1)(2).threshold === Double.MinValue) assert(splits(1)(2).featureType === Categorical) assert(splits(1)(2).categories.length === 3) - assert(splits(1)(2).categories.contains(1.0)) assert(splits(1)(2).categories.contains(0.0)) + assert(splits(1)(2).categories.contains(1.0)) assert(splits(1)(2).categories.contains(2.0)) - - assert(splits(1)(3) === null) + */ // Check bins. - assert(bins(0)(0).category === 1.0) assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories.length === 1) - assert(bins(0)(0).highSplit.categories.contains(1.0)) + assert(bins(0)(0).highSplit.categories === splits(0)(0).categories) - assert(bins(0)(1).category === 0.0) - assert(bins(0)(1).lowSplit.categories.length === 1) - assert(bins(0)(1).lowSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.length === 2) - assert(bins(0)(1).highSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.contains(0.0)) + assert(bins(0)(1).lowSplit.categories === splits(0)(0).categories) + assert(bins(0)(1).highSplit.categories === splits(0)(1).categories) - assert(bins(0)(2).category === 2.0) - assert(bins(0)(2).lowSplit.categories.length === 2) + assert(bins(0)(2).lowSplit.categories === splits(0)(1).categories) + /* + assert(bins(0)(2).lowSplit.categories.length === 1) assert(bins(0)(2).lowSplit.categories.contains(1.0)) - assert(bins(0)(2).lowSplit.categories.contains(0.0)) - assert(bins(0)(2).highSplit.categories.length === 3) - assert(bins(0)(2).highSplit.categories.contains(1.0)) - assert(bins(0)(2).highSplit.categories.contains(0.0)) - assert(bins(0)(2).highSplit.categories.contains(2.0)) + */ - assert(bins(0)(3) === null) + //assert(bins(0)(2).highSplit.categories === splits(0)(2).categories) + assert(bins(0)(2).highSplit.categories === List(2.0, 0.0, 1.0)) - assert(bins(1)(0).category === 0.0) assert(bins(1)(0).lowSplit.categories.length === 0) - assert(bins(1)(0).highSplit.categories.length === 1) - assert(bins(1)(0).highSplit.categories.contains(0.0)) + assert(bins(1)(0).highSplit.categories === splits(1)(0).categories) +/* assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0))*/ - assert(bins(1)(1).category === 1.0) - 
assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories === splits(1)(0).categories) + assert(bins(1)(1).highSplit.categories === splits(1)(1).categories) +/* assert(bins(1)(1).lowSplit.categories.length === 1) assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length === 2) - assert(bins(1)(1).highSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.contains(1.0)) - - assert(bins(1)(2).category === 2.0) - assert(bins(1)(2).lowSplit.categories.length === 2) - assert(bins(1)(2).lowSplit.categories.contains(0.0)) - assert(bins(1)(2).lowSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.length === 3) - assert(bins(1)(2).highSplit.categories.contains(0.0)) - assert(bins(1)(2).highSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.contains(2.0)) + assert(bins(1)(1).highSplit.categories.length === 1) + assert(bins(1)(1).highSplit.categories.contains(1.0))*/ - assert(bins(1)(3) === null) + assert(bins(1)(2).lowSplit.categories === splits(1)(1).categories) + //assert(bins(1)(2).highSplit.categories === splits(1)(2).categories) + assert(bins(1)(2).highSplit.categories === List(2.0, 1.0, 0.0)) +/* assert(bins(1)(2).lowSplit.categories.length === 1) + assert(bins(1)(2).lowSplit.categories.contains(1.0)) */ } test("extract categories from a number for multiclass classification") { @@ -282,7 +274,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) + + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) + assert(metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 3) + assert(bins(0).length === 3) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -320,10 +320,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(1)(2).categories.contains(0.0)) assert(splits(1)(2).categories.contains(1.0)) - assert(splits(0)(3) === null) - assert(splits(1)(3) === null) - - // Check bins. 
assert(bins(0)(0).category === Double.MinValue) @@ -359,9 +355,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(1)(2).highSplit.categories.contains(1.0)) assert(bins(1)(2).highSplit.categories.contains(0.0)) - assert(bins(0)(3) === null) - assert(bins(1)(3) === null) - } test("split and bin calculations for ordered categorical variables with multiclass " + @@ -376,7 +369,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) + + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 9) + assert(bins(0).length === 10) // 2^10 - 1 > 100, so categorical variables will be ordered @@ -399,28 +400,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0)(2).categories.contains(2.0)) assert(splits(0)(2).categories.contains(1.0)) - assert(splits(0)(10) === null) - assert(splits(1)(10) === null) - - // Check bins. assert(bins(0)(0).category === 1.0) assert(bins(0)(0).lowSplit.categories.length === 0) assert(bins(0)(0).highSplit.categories.length === 1) assert(bins(0)(0).highSplit.categories.contains(1.0)) + assert(bins(0)(1).category === 2.0) assert(bins(0)(1).lowSplit.categories.length === 1) assert(bins(0)(1).highSplit.categories.length === 2) assert(bins(0)(1).highSplit.categories.contains(1.0)) assert(bins(0)(1).highSplit.categories.contains(2.0)) - - assert(bins(0)(10) === null) - } - test("classification stump with all categorical variables") { + test("classification stump with all ordered categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -431,14 +426,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - new Array[Node](0), splits, bins, 10, unorderedFeatures) + + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 2) + assert(bins(0).length === 3) + + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 - assert(split.categories.length === 1) - assert(split.categories.contains(1.0)) + assert(split.categories === List(1.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) @@ -459,10 +462,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins, unorderedFeatures) = 
DecisionTree.findSplitsBins(rdd,strategy) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - new Array[Node](0), splits, bins, 10, unorderedFeatures) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -486,6 +493,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) validateRegressor(model, arr, 0.0) @@ -497,8 +507,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) + val strategy = new Strategy(Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -506,9 +521,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0, - new Array[Node](0), splits, bins, 10, unorderedFeatures) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -520,8 +535,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) + val strategy = new Strategy(Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -529,9 +549,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val 
treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - new Array[Node](0), splits, bins, 10, unorderedFeatures) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -544,8 +564,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) + val strategy = new Strategy(Classification, Entropy, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -553,9 +578,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - new Array[Node](0), splits, bins, 10, unorderedFeatures) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -568,8 +593,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) + val strategy = new Strategy(Classification, Entropy, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -577,9 +607,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0, - new Array[Node](0), splits, bins, 10, unorderedFeatures) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -595,7 +625,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val 
strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -609,9 +639,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. - val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, - splits, bins, 10, unorderedFeatures) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, filters, + splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) assert(bestSplits(1)._2.gain > 0) @@ -639,14 +669,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - new Array[Node](0), splits, bins, 10, unorderedFeatures) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) + assert(metadata.isUnordered(featureIndex = 1)) + + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -662,11 +696,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 2) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) + println(model) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) @@ -679,11 +714,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 2) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) @@ -693,16 +728,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification, with just enough bins") {
val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, maxBins = maxBins, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - new Array[Node](0), splits, bins, 10, unorderedFeatures) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -714,7 +750,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(gain.leftImpurity === 0) assert(gain.rightImpurity === 0) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) @@ -722,18 +758,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - new Array[Node](0), splits, bins, 10, unorderedFeatures) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -747,18 +784,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy) - val 
treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - new Array[Node](0), splits, bins, 10, unorderedFeatures) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -771,14 +809,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) - val (splits, bins, unorderedFeatures) = DecisionTree.findSplitsBins(input, strategy) - val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, - new Array[Node](0), splits, bins, 10, unorderedFeatures) + val metadata = LearningMetadata.buildMetadata(rdd, strategy) + + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 From 51ef7813d9fb1c98457f015e1aa7dca92816750a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 13 Aug 2014 01:26:21 -0700 Subject: [PATCH 07/34] Fixed bug introduced by last commit: Variance impurity calculation was incorrect since counts were swapped accidentally --- .../apache/spark/mllib/tree/impurity/Variance.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index c406cd580fc40..63030cc2de5d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -75,21 +75,21 @@ private[tree] class VarianceAggregator extends ImpurityAggregator(3) with Serial } def add(label: Double): Unit = { - counts(0) += label - counts(1) += label * label - counts(2) += 1 + counts(0) += 1 + counts(1) += label + counts(2) += label * label } - def count: Long = counts(2).toLong + def count: Long = counts(0).toLong def predict: Double = if (count == 0) { 0 } else { - counts(0) / counts(2) + counts(1) / count } override def toString: String = { - s"VarianceAggregator(sum = ${counts(0)}, sum2 = ${counts(1)}, cnt = ${counts(2)})" + s"VarianceAggregator(cnt = ${counts(0)}, sum = ${counts(1)}, sum2 = ${counts(2)})" } def newAggregator: VarianceAggregator = { From e3c84ccf06f58fce235fb387c7fd0b432103e5a1 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Wed, 13 Aug 2014 17:49:25 -0700 Subject: [PATCH 08/34] Added stuff fro mnist8m to D T Runner --- .../org/apache/spark/examples/mllib/DecisionTreeRunner.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index cf3d2cca81ff6..b85dfd0b5764b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -156,8 +156,10 @@ object DecisionTreeRunner { throw new IllegalArgumentException("Algo ${params.algo} not supported.") } + println("opt3") // Split into training, test. - val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) + val splitsTMP = examples.randomSplit(Array(0.9, 0.1), seed = 1234) + val splits = splitsTMP(1).randomSplit(Array(1.0 - params.fracTest, params.fracTest), seed = 12345) val training = splits(0).cache() val test = splits(1).cache() val numTraining = training.count() From 86e217fb454e2834f92ecfbebd33419c886fe944 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 13 Aug 2014 17:58:52 -0700 Subject: [PATCH 09/34] added cache to DT input --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 8e4bb40a4b3bb..fe69a9223db2f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -220,7 +220,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo */ timer.reset() - val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata).cache() timer.initTime += timer.elapsed() // depth of the decision tree From 438a66018775dc928644d32e833aecd6c2265109 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 13 Aug 2014 18:17:16 -0700 Subject: [PATCH 10/34] removed subsampling for mnist8m from DT --- .../org/apache/spark/examples/mllib/DecisionTreeRunner.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index b85dfd0b5764b..06771288fda96 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -158,8 +158,7 @@ object DecisionTreeRunner { println("opt3") // Split into training, test. - val splitsTMP = examples.randomSplit(Array(0.9, 0.1), seed = 1234) - val splits = splitsTMP(1).randomSplit(Array(1.0 - params.fracTest, params.fracTest), seed = 12345) + val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest), seed = 12345) val training = splits(0).cache() val test = splits(1).cache() val numTraining = training.count() From dd4d3aa65e796d5b6ac36e36c0172cc90ad4ae15 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Thu, 14 Aug 2014 11:27:34 -0700 Subject: [PATCH 11/34] Mid-process in bug fix: bug for binary classification with categorical features * Bug: Categorical features were all treated as ordered for binary classification. This is possible but would require the bin ordering to be determined on-the-fly after the aggregation. Currently, the ordering is determined a priori and fixed for all splits. * (Temp) Fix: Treat low-arity categorical features as unordered for binary classification. * Related change: I removed most tests for isMulticlass in the code. I instead test metadata for whether there are unordered features. * Status: The bug may be fixed, but more testing needs to be done. Aggregates: The same binMultiplier (for ordered vs. unordered) was applied to all features. It is now applied on a per-feature basis. --- .../spark/mllib/tree/DecisionTree.scala | 31 ++++++------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index fe69a9223db2f..30604539f346f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -132,7 +132,7 @@ private[tree] object LearningMetadata { val unorderedFeatures = new mutable.HashSet[Int]() // numBins[featureIndex] = number of bins for feature val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) - if (numClasses > 2) { + if (numClasses >= 2) { strategy.categoricalFeaturesInfo.foreach { case (f, k) => val numUnorderedBins = DecisionTree.numUnorderedBins(k) if (numUnorderedBins < maxPossibleBins) { @@ -204,7 +204,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.findSplitsBinsTime += timer.elapsed() - /* println(s"splits:") for (f <- Range(0, splits.size)) { for (s <- Range(0, splits(f).size)) { @@ -217,7 +216,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo println(s" bins($f)($s): ${bins(f)(s)}") } } - */ timer.reset() val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata).cache() @@ -271,7 +269,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var break = false while (level <= maxDepth && !break) { - //println(s"LEVEL $level") + println(s"LEVEL $level") logDebug("#####################################") logDebug("level = " + level) logDebug("#####################################") @@ -286,11 +284,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val levelNodeIndexOffset = DecisionTree.maxNodesInLevel(level) - 1 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - /* println(s"splitsStatsForLevel: index=$index") println(s"\t split: ${nodeSplitStats._1}") println(s"\t gain stats: ${nodeSplitStats._2}") - */ val nodeIndex = levelNodeIndexOffset + index val isLeftChild = level != 0 && nodeIndex % 2 == 1 val parentNodeIndex = if (isLeftChild) { // -1 for root node @@ -326,7 +322,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } require(DecisionTree.maxNodesInLevel(level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. 
- //println(s"LOOP over levels: level=$level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(",")}") + println(s"LOOP over levels: level=$level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(",")}") val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) if (allLeaf) { @@ -798,7 +794,7 @@ object DecisionTree extends Serializable with Logging { * @param treePoint Data point being aggregated. * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ - def multiclassWithCategoricalBinSeqOp( + def someUnorderedBinSeqOp( agg: Array[Array[Array[ImpurityAggregator]]], treePoint: TreePoint, nodeIndex: Int): Unit = { @@ -885,10 +881,10 @@ object DecisionTree extends Serializable with Logging { treePoint: TreePoint): Array[Array[Array[ImpurityAggregator]]] = { val nodeIndex = treePointToNodeIndex(treePoint) if (nodeIndex >= 0) { // Otherwise, example does not reach this level. - if (isMulticlassWithCategoricalFeatures) { - multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex) - } else { + if (metadata.unorderedFeatures.isEmpty) { orderedBinSeqOp(agg, treePoint, nodeIndex) + } else { + someUnorderedBinSeqOp(agg, treePoint, nodeIndex) } } agg @@ -926,7 +922,6 @@ object DecisionTree extends Serializable with Logging { input.aggregate(initAgg)(binSeqOp, binCombOp) } - /* println("binAggregates:") for (n <- Range(0, binAggregates.size)) { for (f <- Range(0, binAggregates(n).size)) { @@ -935,7 +930,6 @@ object DecisionTree extends Serializable with Logging { } } } - */ timer.binAggregatesTime += timer.elapsed() @@ -1006,11 +1000,10 @@ object DecisionTree extends Serializable with Logging { val leftImpurity = leftNodeAgg.calculate() // Note: 0 if count = 0 val rightImpurity = rightNodeAgg.calculate() - /* + println(s"calculateGainForSplit") println(s"\t leftImpurity = $leftImpurity, leftNodeAgg: $leftNodeAgg") println(s"\t rightImpurity = $rightImpurity, rightNodeAgg: $rightNodeAgg") - */ val leftWeight = leftCount / totalCount.toDouble val rightWeight = rightCount / totalCount.toDouble @@ -1265,20 +1258,16 @@ object DecisionTree extends Serializable with Logging { case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") } - val binMultiplier = if (metadata.isMulticlassWithCategoricalFeatures) { - 2 - } else { - 1 - } val agg = Array.fill[Array[ImpurityAggregator]](numNodes, metadata.numFeatures)( new Array[ImpurityAggregator](0)) var nodeIndex = 0 while (nodeIndex < numNodes) { var featureIndex = 0 while (featureIndex < metadata.numFeatures) { - var binIndex = 0 + val binMultiplier = if (metadata.isUnordered(featureIndex)) 2 else 1 val effNumBins = metadata.numBins(featureIndex) * binMultiplier agg(nodeIndex)(featureIndex) = new Array[ImpurityAggregator](effNumBins) + var binIndex = 0 while (binIndex < effNumBins) { agg(nodeIndex)(featureIndex)(binIndex) = impurityAggregator.newAggregator binIndex += 1 From 8464a6efd644daf9954ba43c9790ec304f94e029 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 14 Aug 2014 12:26:57 -0700 Subject: [PATCH 12/34] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. 
Made TreePoint extend Serializable --- .../spark/mllib/tree/DecisionTree.scala | 121 ++++-------------- .../spark/mllib/tree/impl/TimeTracker.scala | 75 +++++++++++ .../spark/mllib/tree/impl/TreePoint.scala | 2 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 19 --- 4 files changed, 101 insertions(+), 116 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 17e7a3e65db60..f6cd897a0d760 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.tree -import java.util.Calendar import scala.collection.JavaConverters._ @@ -29,45 +28,12 @@ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impl.TreePoint +import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint} import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom -class TimeTracker { - - var tmpTime: Long = Calendar.getInstance().getTimeInMillis - - def reset(): Unit = { - tmpTime = Calendar.getInstance().getTimeInMillis - } - - def elapsed(): Long = { - Calendar.getInstance().getTimeInMillis - tmpTime - } - - var initTime: Long = 0 // Data retag and cache - var findSplitsBinsTime: Long = 0 - var extractNodeInfoTime: Long = 0 - var extractInfoForLowerLevelsTime: Long = 0 - var findBestSplitsTime: Long = 0 - var findBinsForLevelTime: Long = 0 - var binAggregatesTime: Long = 0 - var chooseSplitsTime: Long = 0 - - override def toString: String = { - s"DecisionTree timing\n" + - s"initTime: $initTime\n" + - s"findSplitsBinsTime: $findSplitsBinsTime\n" + - s"extractNodeInfoTime: $extractNodeInfoTime\n" + - s"extractInfoForLowerLevelsTime: $extractInfoForLowerLevelsTime\n" + - s"findBestSplitsTime: $findBestSplitsTime\n" + - s"findBinsForLevelTime: $findBinsForLevelTime\n" + - s"binAggregatesTime: $binAggregatesTime\n" + - s"chooseSplitsTime: $chooseSplitsTime\n" - } -} /** * :: Experimental :: @@ -90,26 +56,26 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo def train(input: RDD[LabeledPoint]): DecisionTreeModel = { val timer = new TimeTracker() - timer.reset() + timer.start("total") + + timer.start("init") // Cache input RDD for speedup during multiple passes. val retaggedInput = input.retag(classOf[LabeledPoint]) logDebug("algo = " + strategy.algo) - - timer.initTime += timer.elapsed() - timer.reset() + timer.stop("init") // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data.
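// Aside (an assumption, not the actual file): impl/TimeTracker.scala created by this patch
// is not shown in this hunk; judging from the call sites below, it presumably keeps named
// cumulative timers, roughly like:
//   class TimeTracker extends Serializable {
//     private val starts = scala.collection.mutable.Map.empty[String, Long]
//     val totals = scala.collection.mutable.Map.empty[String, Long].withDefaultValue(0L)
//     def start(name: String): Unit = starts(name) = System.nanoTime()
//     def stop(name: String): Unit =
//       totals(name) += System.nanoTime() - starts.remove(name).get
//   }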
+ timer.start("findSplitsBins") val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) val numBins = bins(0).length + timer.stop("findSplitsBins") logDebug("numBins = " + numBins) - timer.findSplitsBinsTime += timer.elapsed() - - timer.reset() + timer.start("init") val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins) - timer.initTime += timer.elapsed() + timer.stop("init") // depth of the decision tree val maxDepth = strategy.maxDepth @@ -166,21 +132,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. - timer.reset() + timer.start("findBestSplits") val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer) - timer.findBestSplitsTime += timer.elapsed() + timer.stop("findBestSplits") for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - timer.reset() + timer.start("extractNodeInfo") // Extract info for nodes at the current level. extractNodeInfo(nodeSplitStats, level, index, nodes) - timer.extractNodeInfoTime += timer.elapsed() - timer.reset() + timer.stop("extractNodeInfo") + timer.start("extractInfoForLowerLevels") // Extract info for nodes at the next lower level. extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) - timer.extractInfoForLowerLevelsTime += timer.elapsed() + timer.stop("extractInfoForLowerLevels") logDebug("final best split = " + nodeSplitStats._1) } require(math.pow(2, level) == splitsStatsForLevel.length) @@ -194,8 +160,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } - println(timer) - logDebug("#####################################") logDebug("Extracting tree model") logDebug("#####################################") @@ -205,6 +169,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Build the full tree using the node info calculated in the level-wise best split calculations. topNode.build(nodes) + timer.stop("total") + + //println(timer) // Print internal timing info. + new DecisionTreeModel(topNode, strategy.algo) } @@ -252,7 +220,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // noting the parents filters for the child nodes val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) - //println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}") for (filter <- filters(nodeIndex)) { logDebug("Filter = " + filter) } @@ -491,7 +458,6 @@ object DecisionTree extends Serializable with Logging { maxLevelForSingleGroup: Int, timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation - //println(s"findBestSplits: level = $level") if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, // the nodes are divided into multiple groups at each level with the number of groups @@ -681,7 +647,6 @@ object DecisionTree extends Serializable with Logging { val parentFilters = findParentFilters(nodeIndex) // Find out whether the sample qualifies for the particular node. 
val sampleValid = isSampleValid(parentFilters, treePoint) - //println(s"==>findBinsForLevel: node:$nodeIndex, valid=$sampleValid, parentFilters:${parentFilters.mkString(",")}") val shift = 1 + numFeatures * nodeIndex if (!sampleValid) { // Mark one bin as -1 is sufficient. @@ -699,12 +664,12 @@ object DecisionTree extends Serializable with Logging { arr } - timer.reset() + timer.start("findBinsForLevel") // Find feature bins for all nodes at a level. val binMappedRDD = input.map(x => findBinsForLevel(x)) - timer.findBinsForLevelTime += timer.elapsed() + timer.stop("findBinsForLevel") /** * Increment aggregate in location for (node, feature, bin, label). @@ -752,7 +717,6 @@ object DecisionTree extends Serializable with Logging { label: Double, agg: Array[Double], rightChildShift: Int): Unit = { - //println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.") // Find the bin index for this feature. val arrIndex = 1 + numFeatures * nodeIndex + featureIndex val featureValue = arr(arrIndex).toInt @@ -830,10 +794,6 @@ object DecisionTree extends Serializable with Logging { // Check whether the instance was valid for this nodeIndex. val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (level == 1) { - val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift - //println(s"-multiclassWithCategoricalBinSeqOp: filter: ${filters(nodeFilterIndex)}") - } if (isSampleValidForNode) { // actual class label val label = arr(0) @@ -954,39 +914,15 @@ object DecisionTree extends Serializable with Logging { combinedAggregate } - timer.reset() // Calculate bin aggregates. + timer.start("binAggregates") val binAggregates = { binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) } + timer.stop("binAggregates") logDebug("binAggregates.length = " + binAggregates.length) - timer.binAggregatesTime += timer.elapsed() - //2 * numClasses * numBins * numFeatures * numNodes for unordered features. - // (left/right, node, feature, bin, label) - /* - println(s"binAggregates:") - for (i <- Range(0,2)) { - for (n <- Range(0,numNodes)) { - for (f <- Range(0,numFeatures)) { - for (b <- Range(0,4)) { - for (c <- Range(0,numClasses)) { - val idx = i * numClasses * numBins * numFeatures * numNodes + - n * numClasses * numBins * numFeatures + - f * numBins * numFeatures + - b * numFeatures + - c - if (binAggregates(idx) != 0) { - println(s"\t ($i, c:$c, b:$b, f:$f, n:$n): ${binAggregates(idx)}") - } - } - } - } - } - } - */ - /** * Calculates the information gain for all splits based upon left/right split aggregates. * @param leftNodeAgg left node aggregates @@ -1027,7 +963,6 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftTotalCount + rightTotalCount if (totalCount == 0) { // Return arbitrary prediction. - //println(s"BLAH: feature $featureIndex, split $splitIndex. 
totalCount == 0") return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) } @@ -1054,9 +989,6 @@ object DecisionTree extends Serializable with Logging { } val predict = indexOfLargestArrayElement(leftRightCounts) - if (predict == 0 && featureIndex == 0 && splitIndex == 0) { - //println(s"AGHGHGHHGHG: leftCounts: ${leftCounts.mkString(",")}, rightCounts: ${rightCounts.mkString(",")}") - } val prob = leftRightCounts(predict) / totalCount val leftImpurity = if (leftTotalCount == 0) { @@ -1209,7 +1141,6 @@ object DecisionTree extends Serializable with Logging { } splitIndex += 1 } - //println(s"found Agg: $TMPDEBUG") } def findAggForRegression( @@ -1369,7 +1300,6 @@ object DecisionTree extends Serializable with Logging { bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex - //println(s" feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats") } splitIndex += 1 } @@ -1414,7 +1344,7 @@ object DecisionTree extends Serializable with Logging { } } - timer.reset() + timer.start("chooseSplits") // Calculate best splits for all nodes at a given level val bestSplits = new Array[(Split, InformationGainStats)](numNodes) @@ -1427,10 +1357,9 @@ object DecisionTree extends Serializable with Logging { val parentNodeImpurity = parentImpurities(nodeImpurityIndex) logDebug("parent node impurity = " + parentNodeImpurity) bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) - //println(s"bestSplits(node:$node): ${bestSplits(node)}") node += 1 } - timer.chooseSplitsTime += timer.elapsed() + timer.stop("chooseSplits") bestSplits } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala new file mode 100644 index 0000000000000..251b9c2f0eaeb --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable.{HashMap => MutableHashMap} + +import org.apache.spark.annotation.Experimental + +/** + * Time tracker implementation which holds labeled timers. + */ +@Experimental +private[tree] +class TimeTracker extends Serializable { + + private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + /** + * Starts a new timer, or re-starts a stopped timer. 
+ */ + def start(timerLabel: String): Unit = { + val tmpTime = System.nanoTime() + if (starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + + s" timerLabel = $timerLabel before that timer was stopped.") + } + starts(timerLabel) = tmpTime + } + + /** + * Stops a timer and returns the elapsed time in nanoseconds. + */ + def stop(timerLabel: String): Long = { + val tmpTime = System.nanoTime() + if (!starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + + s" timerLabel = $timerLabel, but that timer was not started.") + } + val elapsed = tmpTime - starts(timerLabel) + starts.remove(timerLabel) + if (totals.contains(timerLabel)) { + totals(timerLabel) += elapsed + } else { + totals(timerLabel) = elapsed + } + elapsed + } + + /** + * Print all timing results. + */ + override def toString: String = { + s"Timing\n" + + totals.map { case (label, elapsed) => + s" $label: $elapsed" + }.mkString("\n") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index f3b5dce041207..bd2cdae968124 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD * of size (numFeatures, numBins). * TODO: ADD DOC */ -private[tree] class TreePoint(val label: Double, val features: Array[Int]) { +private[tree] class TreePoint(val label: Double, val features: Array[Int]) extends Serializable { } private[tree] object TreePoint { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5666064647a10..8708d2392a825 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -689,7 +689,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } test("stump with categorical variables for multiclass classification, with just enough bins") { - println("START: stump with categorical variables for multiclass classification, with just enough bins") val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val input = sc.parallelize(arr) @@ -701,22 +700,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, Array[List[Filter]](), splits, bins, 10) - println(s"splits:") - for (feature <- Range(0,splits.size)) { - for (i <- Range(0,3)) { - println(s" f:$feature [$i]: ${splits(feature)(i)}") - } - } - println(s"bins:") - for (feature <- Range(0,bins.size)) { - for (i <- Range(0,4)) { - println(s" f:$feature [$i]: ${bins(feature)(i)}") - } - } - println(s"bestSplits:") - bestSplits.foreach { x => - println(s"\t $x") - } assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -729,11 +712,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(gain.rightImpurity === 0) val model = DecisionTree.train(input, strategy) - println(model) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) - println("END: stump with 
categorical variables for multiclass classification, with just enough bins") } test("stump with continuous variables for multiclass classification") { From e66f1b1cb2252dab1f847f2c24623baab40627fc Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 14 Aug 2014 12:58:22 -0700 Subject: [PATCH 13/34] TreePoint * Updated doc * Made some methods private Changed timer to report time in seconds. --- .../spark/mllib/tree/DecisionTree.scala | 25 +++-------- .../spark/mllib/tree/impl/TimeTracker.scala | 10 ++--- .../spark/mllib/tree/impl/TreePoint.scala | 44 ++++++++++++++----- .../spark/mllib/tree/DecisionTreeSuite.scala | 10 ++--- 4 files changed, 49 insertions(+), 40 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index f6cd897a0d760..1845fc061e52e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -24,12 +24,12 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint} -import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity} +import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom @@ -59,8 +59,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.start("total") - timer.start("init") // Cache input RDD for speedup during multiple passes. + timer.start("init") val retaggedInput = input.retag(classOf[LabeledPoint]) logDebug("algo = " + strategy.algo) timer.stop("init") @@ -74,7 +74,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("numBins = " + numBins) timer.start("init") - val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins) + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins).cache() timer.stop("init") // depth of the decision tree @@ -90,7 +90,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) // num features - val numFeatures = retaggedInput.take(1)(0).features.size + val numFeatures = treeInput.take(1)(0).features.size // Calculate level for single group construction @@ -118,10 +118,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * still survived the filters of the parent nodes. */ - var findBestSplitsTime: Long = 0 - var extractNodeInfoTime: Long = 0 - var extractInfoForLowerLevelsTime: Long = 0 - var level = 0 var break = false while (level <= maxDepth && !break) { @@ -618,8 +614,6 @@ object DecisionTree extends Serializable with Logging { true } - // TODO: REMOVED findBin() - /** * Finds bins for all nodes (and all features) at a given level. 
* For l nodes, k features the storage is as follows:
@@ -664,11 +658,9 @@ object DecisionTree extends Serializable with Logging {
       arr
     }

-    timer.start("findBinsForLevel")
-    // Find feature bins for all nodes at a level.
+    // Find feature bins for all nodes at a level.
+    timer.start("findBinsForLevel")
     val binMappedRDD = input.map(x => findBinsForLevel(x))
-
     timer.stop("findBinsForLevel")
@@ -1126,7 +1118,6 @@ object DecisionTree extends Serializable with Logging {
       val rightChildShift = numClasses * numBins * numFeatures
       var splitIndex = 0
-      var TMPDEBUG = 0.0
       while (splitIndex < numBins - 1) {
         var classIndex = 0
         while (classIndex < numClasses) {
@@ -1136,7 +1127,6 @@ object DecisionTree extends Serializable with Logging {
           val rightBinValue = binData(rightChildShift + shift + classIndex)
           leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue
           rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue
-          TMPDEBUG += leftBinValue + rightBinValue
           classIndex += 1
         }
         splitIndex += 1
@@ -1344,9 +1334,8 @@ object DecisionTree extends Serializable with Logging {
       }
     }

-    timer.start("chooseSplits")
-    // Calculate best splits for all nodes at a given level
+    // Calculate best splits for all nodes at a given level
+    timer.start("chooseSplits")
     val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
     // Iterating over all nodes at this level
     var node = 0
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
index 251b9c2f0eaeb..60ecd7c589574 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
@@ -45,9 +45,9 @@ class TimeTracker extends Serializable {
   }

   /**
-   * Stops a timer and returns the elapsed time in nanoseconds.
+   * Stops a timer and returns the elapsed time in seconds.
    */
-  def stop(timerLabel: String): Long = {
+  def stop(timerLabel: String): Double = {
     val tmpTime = System.nanoTime()
@@ -60,16 +60,16 @@ class TimeTracker extends Serializable {
     } else {
       totals(timerLabel) = elapsed
     }
-    elapsed
+    elapsed / 1e9
   }

   /**
-   * Print all timing results.
+   * Print all timing results in seconds.
    */
   override def toString: String = {
     s"Timing\n" +
       totals.map { case (label, elapsed) =>
-        s"  $label: $elapsed"
+        s"  $label: ${elapsed / 1e9}"
       }.mkString("\n")
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
index bd2cdae968124..fb9691ba480ec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -22,16 +22,36 @@ import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.model.Bin
 import org.apache.spark.rdd.RDD

+
 /**
- * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
- * of size (numFeatures, numBins).
- * TODO: ADD DOC
+ * Internal representation of LabeledPoint for DecisionTree.
+ * This bins feature values based on a subsample of data as follows:
+ *  (a) Continuous features are binned into ranges.
+ *  (b) Unordered categorical features are binned based on subsets of feature values.
+ *      "Unordered categorical features" are categorical features with low arity used in
+ *      multiclass classification.
+ *  (c) Ordered categorical features are binned based on feature values.
+ * "Ordered categorical features" are categorical features with high arity, + * or any categorical feature used in regression or binary classification. + * + * @param label Label from LabeledPoint + * @param features Binned feature values. + * Same length as LabeledPoint.features, but values are bin indices. */ private[tree] class TreePoint(val label: Double, val features: Array[Int]) extends Serializable { } + private[tree] object TreePoint { + /** + * Convert an input dataset into its TreePoint representation, + * binning feature values in preparation for DecisionTree training. + * @param input Input dataset. + * @param strategy DecisionTree training info, used for dataset metadata. + * @param bins Bins for features, of size (numFeatures, numBins). + * @return TreePoint dataset representation + */ def convertToTreeRDD( input: RDD[LabeledPoint], strategy: Strategy, @@ -42,7 +62,12 @@ private[tree] object TreePoint { } } - def labeledPointToTreePoint( + /** + * Convert one LabeledPoint into its TreePoint representation. + * @param bins Bins for features, of size (numFeatures, numBins). + * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity + */ + private def labeledPointToTreePoint( labeledPoint: LabeledPoint, isMulticlassClassification: Boolean, bins: Array[Array[Bin]], @@ -77,16 +102,11 @@ private[tree] object TreePoint { /** * Find bin for one (labeledPoint, feature). * - * @param featureIndex - * @param labeledPoint - * @param isFeatureContinuous * @param isUnorderedFeature (only applies if feature is categorical) - * @param bins Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] - * of size (numFeatures, numBins). - * @param categoricalFeaturesInfo - * @return + * @param bins Bins for features, of size (numFeatures, numBins). + * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity */ - def findBin( + private def findBin( featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous: Boolean, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 8708d2392a825..1019adc1478b1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -696,6 +696,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) + val model = DecisionTree.train(input, strategy) + validateClassifier(model, arr, 1.0) + assert(model.numNodes === 3) + assert(model.depth === 1) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0, @@ -710,11 +715,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val gain = bestSplits(0)._2 assert(gain.leftImpurity === 0) assert(gain.rightImpurity === 0) - - val model = DecisionTree.train(input, strategy) - validateClassifier(model, arr, 1.0) - assert(model.numNodes === 3) - assert(model.depth === 1) } test("stump with continuous variables for multiclass classification") { From d03608949e19c53596b4f6cc09d9f68011184d68 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Thu, 14 Aug 2014 13:07:14 -0700 Subject: [PATCH 14/34] Print timing info to logDebug. --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 3 ++- .../org/apache/spark/mllib/tree/impl/TimeTracker.scala | 7 +++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1845fc061e52e..89a0464f39606 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -167,7 +167,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.stop("total") - //println(timer) // Print internal timing info. + logDebug("Internal timing for DecisionTree:") + logDebug(s"$timer") new DecisionTreeModel(topNode, strategy.algo) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala index 60ecd7c589574..bae903a94ab17 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala @@ -67,9 +67,8 @@ class TimeTracker extends Serializable { * Print all timing results in seconds. */ override def toString: String = { - s"Timing\n" + - totals.map { case (label, elapsed) => - s" $label: ${elapsed / 1e9}" - }.mkString("\n") + totals.map { case (label, elapsed) => + s" $label: ${elapsed / 1e9}" + }.mkString("\n") } } From 430d782294a08f63535e2ecce167703021e1fe44 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 14 Aug 2014 16:09:14 -0700 Subject: [PATCH 15/34] Added more debug info on binning error. Added some docs. --- .../mllib/tree/configuration/Strategy.scala | 43 +++++++++---------- .../spark/mllib/tree/impl/TreePoint.scala | 6 ++- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index f31a503608b22..cfc8192a85abd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -27,22 +27,30 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ /** * :: Experimental :: * Stores all the configuration options for tree construction - * @param algo classification or regression - * @param impurity criterion used for information gain calculation + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @param impurity Criterion used for information gain calculation. + * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], + * [[org.apache.spark.mllib.tree.impurity.Entropy]]. + * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClassesForClassification number of classes for classification. 
Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param numClassesForClassification Number of classes for classification. + * (Ignored for regression.) + * Default value is 2 (binary classification). + * @param maxBins Maximum number of bins used for discretizing continuous features and + * for choosing how to split on features at each node. + * More bins give higher granularity. + * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: + * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the * number of discrete values they take. For example, an entry (n -> * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. - * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. - * */ @Experimental class Strategy ( @@ -64,20 +72,7 @@ class Strategy ( = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** - * Java-friendly constructor. - * - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClassesForClassification number of classes for classification. Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, an entry - * (n -> k) implies the feature n is categorical with k categories - * 0, 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. + * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] */ def this( algo: Algo, @@ -90,6 +85,10 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + /** + * Check validity of parameters. + * Throws exception if invalid. + */ private[tree] def assertValid(): Unit = { algo match { case Classification => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index fb9691ba480ec..8ba72bc32cb09 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -181,7 +181,8 @@ private[tree] object TreePoint { // Perform binary search for finding bin for continuous features. val binIndex = binarySearchForBins() if (binIndex == -1) { - throw new UnknownError("no bin was found for continuous variable.") + throw new UnknownError("No bin was found for continuous feature." + + s" Feature index: $featureIndex. 
Feature value: ${labeledPoint.features(featureIndex)}")
       }
       binIndex
     } else {
@@ -192,7 +193,8 @@ private[tree] object TreePoint {
         sequentialBinSearchForOrderedCategoricalFeature()
       }
       if (binIndex == -1) {
-        throw new UnknownError("no bin was found for categorical variable.")
+        throw new UnknownError("No bin was found for categorical feature." +
+          s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
       }
       binIndex
     }

From 26d10dd58ee218102bd205c1e6d68fda5a45cf4b Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Thu, 14 Aug 2014 17:44:08 -0700
Subject: [PATCH 16/34] Removed tree/model/Filter.scala since no longer used.
 Removed debugging println calls in DecisionTree.scala.

---
 .../spark/mllib/tree/DecisionTree.scala | 39 +++----------
 .../spark/mllib/tree/model/Filter.scala | 28 -------------
 2 files changed, 6 insertions(+), 61 deletions(-)
 delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 919a25b65e9b5..af306d57d88cd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -62,7 +62,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

     timer.start("total")

-    // Cache input RDD for speedup during multiple passes.
     timer.start("init")
     val retaggedInput = input.retag(classOf[LabeledPoint])
     logDebug("algo = " + strategy.algo)
     timer.stop("init")
@@ -77,6 +76,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     logDebug("numBins = " + numBins)

     timer.start("init")
+    // Bin feature values (TreePoint representation).
+    // Cache input RDD for speedup during multiple passes.
     val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins).cache()
     timer.stop("init")

@@ -84,10 +85,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     val maxDepth = strategy.maxDepth
     // the max number of nodes possible given the depth of the tree
     val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1
-    // Initialize an array to hold filters applied to points for each node.
-    //val filters = new Array[List[Filter]](maxNumNodes)
-    // The filter at the top node is an empty list.
-    //filters(0) = List()
     // Initialize an array to hold parent impurity calculations for each node.
     val parentImpurities = new Array[Double](maxNumNodes)
     // dummy value for top node (updated during first split calculation)
     val nodes = new Array[Node](maxNumNodes)
@@ -118,9 +115,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     /*
      * The main idea here is to perform level-wise training of the decision tree nodes thus
      * reducing the passes over the data from l to log2(l) where l is the total number of nodes.
-     * Each data sample is checked for validity w.r.t to each node at a given level -- i.e.,
-     * the sample is only used for the split calculation at the node if the sampled would have
-     * still survived the filters of the parent nodes.
+     * Each data sample is handled by a particular node at that level (or it reaches a leaf
+     * beforehand and is not used in later levels).
      */

     var level = 0
     var break = false
     while (level <= maxDepth && !break) {
@@ -169,7 +165,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       }
       require(math.pow(2, level) == splitsStatsForLevel.length)
       // Check whether all the nodes at the current level are leaves.
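As a quick sanity check on the sizes used by this level-wise loop (a sketch, not patch content): a tree of depth maxDepth has at most 2^(maxDepth + 1) - 1 nodes, and level l contributes exactly 2^l of them, which is what the require(...) above asserts:

    // Node-count arithmetic behind maxNumNodes and the per-level require(...).
    val maxDepth = 2
    val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1              // 7
    val nodesPerLevel = (0 to maxDepth).map(l => math.pow(2, l).toInt) // Vector(1, 2, 4)
    assert(nodesPerLevel.sum == maxNumNodes)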
- println(s"LOOP over levels: level=$level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(",")}") val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) if (allLeaf) { @@ -237,8 +232,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) // noting the parent impurities parentImpurities(nodeIndex) = impurity - // noting the parents filters for the child nodes - val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) i += 1 } } @@ -461,7 +454,6 @@ object DecisionTree extends Serializable with Logging { * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. * @param unorderedFeatures Set of unordered (categorical) features. * @return array (over nodes) of splits with best split for each node at a given level. - * TODO: UPDATE DOC */ protected[tree] def findBestSplits( input: RDD[TreePoint], @@ -512,7 +504,6 @@ object DecisionTree extends Serializable with Logging { * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. - * TODO: UPDATE DOC */ private def findBestSplitsPerGroup( input: RDD[TreePoint], @@ -539,7 +530,7 @@ object DecisionTree extends Serializable with Logging { * We use a bin-wise best split computation strategy instead of a straightforward best split * computation strategy. Instead of analyzing each sample for contribution to the left/right * child node impurity of every split, we first categorize each feature of a sample into a - * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, + * bin. Each bin is an interval between a low and high split. Since each split, and thus bin, * is ordered (read ordering for categorical variables in the findSplitsBins method), * we exploit this structure to calculate aggregates for bins and then use these aggregates * to calculate information gain for each split. @@ -660,7 +651,6 @@ object DecisionTree extends Serializable with Logging { * numClasses * numBins * numFeatures * numNodes. * Indexed by (node, feature, bin, label) where label is the least significant bit. * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - * TODO: UPDATE DOC */ def updateBinForOrderedFeature( treePoint: TreePoint, @@ -681,13 +671,12 @@ object DecisionTree extends Serializable with Logging { * where [bins] ranges over all bins. * Updates left or right side of aggregate depending on split. * + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). * @param treePoint Data point being aggregated. * @param agg Indexed by (left/right, node, feature, bin, label) * where label is the least significant bit. * The left/right specifier is a 0/1 index indicating left/right child info. * @param rightChildShift Offset for right side of agg. - * TODO: UPDATE DOC - * TODO: Make arg order same as for ordered feature. 
*/ def updateBinForUnorderedFeature( nodeIndex: Int, @@ -695,7 +684,6 @@ object DecisionTree extends Serializable with Logging { treePoint: TreePoint, agg: Array[Double], rightChildShift: Int): Unit = { - //println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.") val featureValue = treePoint.features(featureIndex) // Update the left or right count for one bin. val aggShift = @@ -780,7 +768,6 @@ object DecisionTree extends Serializable with Logging { * @return agg */ def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = { - // TODO: Move stuff outside loop. val label = treePoint.label // Iterate over all features. var featureIndex = 0 @@ -791,9 +778,6 @@ object DecisionTree extends Serializable with Logging { 3 * numBins * numFeatures * nodeIndex + 3 * numBins * featureIndex + 3 * binIndex - if (aggIndex >= agg.size) { - println(s"aggIndex = $aggIndex, agg.size = ${agg.size}. binIndex = $binIndex, featureIndex = $featureIndex, nodeIndex = $nodeIndex, numBins = $numBins, numFeatures = $numFeatures") - } agg(aggIndex) = agg(aggIndex) + 1 agg(aggIndex + 1) = agg(aggIndex + 1) + label agg(aggIndex + 2) = agg(aggIndex + 2) + label * label @@ -1025,7 +1009,6 @@ object DecisionTree extends Serializable with Logging { * Element i (i = 1, ..., numSplits - 1) is set to be * the cumulative sum (from right) over binData for bins * numBins - 1, ..., numBins - 1 - i. - * TODO: We could avoid doing one of these cumulative sums. */ def findAggForOrderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], @@ -1196,16 +1179,6 @@ object DecisionTree extends Serializable with Logging { } else { featureCategories } - /* - val isSpaceSufficientForAllCategoricalSplits = - numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { - math.pow(2.0, featureCategories - 1).toInt - 1 - } else { - // Ordered features - featureCategories - } - */ } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala deleted file mode 100644 index 2deaf4ae8dcab..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mllib.tree.model - -/** - * Filter specifying a split and type of comparison to be applied on features - * @param split split specifying the feature index, type and threshold - * @param comparison integer specifying <,=,> - */ -private[tree] case class Filter(split: Split, comparison: Int) { - // Comparison -1,0,1 signifies <.=,> - override def toString = " split = " + split + "comparison = " + comparison -} From 9c833639ec935a1f372ea0655012259957d8778b Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 17 Aug 2014 20:23:17 -0700 Subject: [PATCH 17/34] partial merge but not done yet --- .../spark/mllib/tree/DecisionTree.scala | 881 ++---------------- .../tree/impl/DecisionTreeMetadata.scala | 4 +- .../spark/mllib/tree/impl/TimeTracker.scala | 21 - .../spark/mllib/tree/impl/TreePoint.scala | 60 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 135 +-- 5 files changed, 86 insertions(+), 1015 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 2a48287abc882..0a3cfcfe4a104 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -62,16 +62,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.start("total") timer.start("init") -<<<<<<< HEAD - val retaggedInput = input.retag(classOf[LabeledPoint]) - val metadata = LearningMetadata.buildMetadata(retaggedInput, strategy) - timer.stop("init") - -======= val retaggedInput = input.retag(classOf[LabeledPoint]) val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy) ->>>>>>> upstream/master logDebug("algo = " + strategy.algo) logDebug("maxBins = " + metadata.maxBins) @@ -83,7 +76,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.stop("findSplitsBins") logDebug("numBins = " + numBins) -<<<<<<< HEAD /* println(s"splits:") for (f <- Range(0, splits.size)) { @@ -99,50 +91,32 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } */ - timer.start("init") - // Bin feature values (TreePoint representation). - // Cache input RDD for speedup during multiple passes. - val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata).cache() - timer.stop("init") -======= // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) .persist(StorageLevel.MEMORY_AND_DISK) ->>>>>>> upstream/master val numFeatures = metadata.numFeatures // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree -<<<<<<< HEAD val maxNumNodes = DecisionTree.maxNodesInLevel(maxDepth + 1) - 1 -======= - val maxNumNodes = (2 << maxDepth) - 1 ->>>>>>> upstream/master + // TODO: CHECK val maxNumNodes = (2 << maxDepth) - 1 // Initialize an array to hold parent impurity calculations for each node. val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) -<<<<<<< HEAD - val nodesInTree = Array.fill[Boolean](maxNumNodes)(false) // put into nodes array later? 
- nodesInTree(0) = true -======= ->>>>>>> upstream/master + // TODO: DO THIS OPTIMIZATION: + // val nodesInTree = Array.fill[Boolean](maxNumNodes)(false) // put into nodes array later? + // nodesInTree(0) = true // Calculate level for single group construction // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") -<<<<<<< HEAD // TODO: Calculate numElementsPerNode in metadata (more precisely) - val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins, - strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures, - strategy.algo) -======= val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins) ->>>>>>> upstream/master logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array @@ -174,7 +148,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. timer.start("findBestSplits") -<<<<<<< HEAD val splitsStatsForLevel: Array[(Split, InformationGainStats)] = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) @@ -185,14 +158,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo /*println(s"splitsStatsForLevel: index=$index") println(s"\t split: ${nodeSplitStats._1}") println(s"\t gain stats: ${nodeSplitStats._2}")*/ -======= - val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, - metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) - timer.stop("findBestSplits") - - val levelNodeIndexOffset = (1 << level) - 1 - for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { ->>>>>>> upstream/master val nodeIndex = levelNodeIndexOffset + index val isLeftChild = level != 0 && nodeIndex % 2 == 1 val parentNodeIndex = if (isLeftChild) { // -1 for root node @@ -200,9 +165,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } else { (nodeIndex - 2) / 2 } -<<<<<<< HEAD // if (level == 0 || (nodesInTree(parentNodeIndex) && !nodes(parentNodeIndex).isLeaf)) // TODO: Use above check to skip unused branch of tree + // Extract info for this node (index) at the current level. timer.start("extractNodeInfo") val split = nodeSplitStats._1 @@ -213,12 +178,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo nodes(nodeIndex) = node timer.stop("extractNodeInfo") -======= - // Extract info for this node (index) at the current level. - timer.start("extractNodeInfo") - extractNodeInfo(nodeSplitStats, level, index, nodes) - timer.stop("extractNodeInfo") ->>>>>>> upstream/master if (level != 0) { // Set parent. if (isLeftChild) { @@ -233,11 +192,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.stop("extractInfoForLowerLevels") logDebug("final best split = " + nodeSplitStats._1) } -<<<<<<< HEAD require(DecisionTree.maxNodesInLevel(level) == splitsStatsForLevel.length) -======= - require((1 << level) == splitsStatsForLevel.length) ->>>>>>> upstream/master // Check whether all the nodes at the current level at leaves. 
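Two pieces of index arithmetic in the merged code above can be checked directly (a sketch under the assumption, consistent with the require(...) above, that maxNodesInLevel(level) equals 2^level):

    // (1) The "TODO: CHECK" holds: both sides of the maxNumNodes conflict compute the same value.
    def maxNodesInLevel(level: Int): Int = 1 << level
    val maxDepth = 3
    assert(maxNodesInLevel(maxDepth + 1) - 1 == (2 << maxDepth) - 1)   // both 15

    // (2) Children of node i sit at 2i+1 (left, odd index) and 2i+2 (right, even index),
    // so the parent-index formulas used above recover i.
    for (i <- 0 until 7) {
      val (left, right) = (2 * i + 1, 2 * i + 2)
      assert(left % 2 == 1 && (left - 1) / 2 == i)
      assert(right % 2 == 0 && (right - 2) / 2 == i)
    }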
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -258,13 +213,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo topNode.build(nodes) timer.stop("total") -<<<<<<< HEAD - - logDebug("Internal timing for DecisionTree:") - logDebug(s"$timer") - - new DecisionTreeModel(topNode, strategy.algo) -======= logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") @@ -272,24 +220,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo new DecisionTreeModel(topNode, strategy.algo) } - /** - * Extract the decision tree node information for the given tree level and node index - */ - private def extractNodeInfo( - nodeSplitStats: (Split, InformationGainStats), - level: Int, - index: Int, - nodes: Array[Node]): Unit = { - val split = nodeSplitStats._1 - val stats = nodeSplitStats._2 - val nodeIndex = (1 << level) - 1 + index - val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) - val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) - logDebug("Node = " + node) - nodes(nodeIndex) = node ->>>>>>> upstream/master - } - /** * Extract the decision tree node information for the children of the node */ @@ -299,32 +229,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), parentImpurities: Array[Double]): Unit = { -<<<<<<< HEAD - if (level >= maxDepth) - return - // TODO: Move nodeIndexOffset calc out of function? - val nodeIndexOffset = DecisionTree.maxNodesInLevel(level + 1) - 1 - // 0 corresponds to the left child node and 1 corresponds to the right child node. - var i = 0 - while (i <= 1) { - // Calculate the index of the node from the node level and the index at the current level. - val nodeIndex = nodeIndexOffset + 2 * index + i - val impurity = if (i == 0) { - nodeSplitStats._2.leftImpurity - } else { - nodeSplitStats._2.rightImpurity - } - logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) - // noting the parent impurities - parentImpurities(nodeIndex) = impurity - i += 1 -======= - if (level >= maxDepth) { return ->>>>>>> upstream/master } - + // TODO: Move nodeIndexOffset calc out of function? val leftNodeIndex = (2 << level) - 1 + 2 * index val leftImpurity = nodeSplitStats._2.leftImpurity logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity) @@ -548,26 +456,15 @@ object DecisionTree extends Serializable with Logging { * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree -<<<<<<< HEAD * @param splits possible splits for all features, indexed (numFeatures)(numSplits) * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. - * @param unorderedFeatures Set of unordered (categorical) features. -======= - * @param splits possible splits for all features - * @param bins possible bins for all features - * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. ->>>>>>> upstream/master * @return array (over nodes) of splits with best split for each node at a given level. 
*/ protected[tree] def findBestSplits( input: RDD[TreePoint], parentImpurities: Array[Double], -<<<<<<< HEAD - metadata: LearningMetadata, -======= metadata: DecisionTreeMetadata, ->>>>>>> upstream/master level: Int, nodes: Array[Node], splits: Array[Array[Split]], @@ -580,11 +477,7 @@ object DecisionTree extends Serializable with Logging { // the nodes are divided into multiple groups at each level with the number of groups // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. -<<<<<<< HEAD - val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt -======= val numGroups = 1 << level - maxLevelForSingleGroup ->>>>>>> upstream/master logDebug("numGroups = " + numGroups) var bestSplits = new Array[(Split, InformationGainStats)](0) // Iterate over each group of nodes at a level. @@ -608,14 +501,8 @@ object DecisionTree extends Serializable with Logging { * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree -<<<<<<< HEAD * @param splits possible splits for all features, indexed (numFeatures)(numSplits) * @param bins possible bins for all features, indexed (numFeatures)(numBins) - * @param unorderedFeatures Set of unordered (categorical) features. -======= - * @param splits possible splits for all features - * @param bins possible bins for all features, indexed as (numFeatures)(numBins) ->>>>>>> upstream/master * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. @@ -623,11 +510,7 @@ object DecisionTree extends Serializable with Logging { private def findBestSplitsPerGroup( input: RDD[TreePoint], parentImpurities: Array[Double], -<<<<<<< HEAD - metadata: LearningMetadata, -======= metadata: DecisionTreeMetadata, ->>>>>>> upstream/master level: Int, nodes: Array[Node], splits: Array[Array[Split]], @@ -664,11 +547,7 @@ object DecisionTree extends Serializable with Logging { // numNodes: Number of nodes in this (level of tree, group), // where nodes at deeper (larger) levels may be divided into groups. -<<<<<<< HEAD val numNodes = DecisionTree.maxNodesInLevel(level) / numGroups -======= - val numNodes = (1 << level) / numGroups ->>>>>>> upstream/master logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. @@ -686,24 +565,13 @@ object DecisionTree extends Serializable with Logging { logDebug("isMulticlass = " + isMulticlass) val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures -<<<<<<< HEAD logDebug("isMulticlassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) -======= - logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) ->>>>>>> upstream/master // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex /** * Get the node index corresponding to this data point. -<<<<<<< HEAD - * This is used during training, mimicking prediction. - * @return Leaf index if the data point reaches a leaf. - * Otherwise, last node reachable in tree matching this example. 
- */ - def predictNodeIndex(node: Node, features: Array[Int]): Int = { -======= * This function mimics prediction, passing an example from the root node down to a node * at the current level being trained; that node's index is returned. * @@ -711,18 +579,13 @@ object DecisionTree extends Serializable with Logging { * Otherwise, last node reachable in tree matching this example. */ def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = { ->>>>>>> upstream/master if (node.isLeaf) { node.id } else { val featureIndex = node.split.get.feature val splitLeft = node.split.get.featureType match { case Continuous => { -<<<<<<< HEAD - val binIndex = features(featureIndex) -======= val binIndex = binnedFeatures(featureIndex) ->>>>>>> upstream/master val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] // We do not need to check lowSplit since bins are separated by splits. @@ -730,65 +593,11 @@ object DecisionTree extends Serializable with Logging { } case Categorical => { val featureValue = if (metadata.isUnordered(featureIndex)) { -<<<<<<< HEAD - features(featureIndex) - } else { - val binIndex = features(featureIndex) - bins(featureIndex)(binIndex).category - } - node.split.get.categories.contains(featureValue) - } - case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") - } - if (node.leftNode.isEmpty || node.rightNode.isEmpty) { - // Return index from next layer of nodes to train - if (splitLeft) { - node.id * 2 + 1 // left - } else { - node.id * 2 + 2 // right - } - } else { - if (splitLeft) { - predictNodeIndex(node.leftNode.get, features) - } else { - predictNodeIndex(node.rightNode.get, features) - } - } - } - } - - def nodeIndexToLevel(idx: Int): Int = { - if (idx == 0) { - 0 - } else { - math.floor(math.log(idx) / math.log(2)).toInt - } - } - - // Used for treePointToNodeIndex - val levelOffset = DecisionTree.maxNodesInLevel(level) - 1 - - /** - * Find the node (indexed from 0 at the start of this level) for the given example. - * If the example does not reach this level, returns a value < 0. - */ - def treePointToNodeIndex(treePoint: TreePoint): Int = { - if (level == 0) { - 0 - } else { - val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.features) - // Get index for this level. - globalNodeIndex - levelOffset - } - } - - val rightChildShift = numClasses * numBins * numFeatures * numNodes -======= - binnedFeatures(featureIndex) - } else { - val binIndex = binnedFeatures(featureIndex) - bins(featureIndex)(binIndex).category - } + binnedFeatures(featureIndex) + } else { + val binIndex = binnedFeatures(featureIndex) + bins(featureIndex)(binIndex).category + } node.split.get.categories.contains(featureValue) } case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") @@ -819,7 +628,7 @@ object DecisionTree extends Serializable with Logging { } // Used for treePointToNodeIndex - val levelOffset = (1 << level) - 1 + val levelOffset = DecisionTree.maxNodesInLevel(level) - 1 /** * Find the node index for the given example. @@ -836,89 +645,6 @@ object DecisionTree extends Serializable with Logging { } } - /** - * Increment aggregate in location for (node, feature, bin, label). - * - * @param treePoint Data point being aggregated. - * @param agg Array storing aggregate calculation, of size: - * numClasses * numBins * numFeatures * numNodes. - * Indexed by (node, feature, bin, label) where label is the least significant bit. 
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - */ - def updateBinForOrderedFeature( - treePoint: TreePoint, - agg: Array[Double], - nodeIndex: Int, - featureIndex: Int): Unit = { - // Update the left or right count for one bin. - val aggIndex = - numClasses * numBins * numFeatures * nodeIndex + - numClasses * numBins * featureIndex + - numClasses * treePoint.binnedFeatures(featureIndex) + - treePoint.label.toInt - agg(aggIndex) += 1 - } - - /** - * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label), - * where [bins] ranges over all bins. - * Updates left or right side of aggregate depending on split. - * - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - * @param treePoint Data point being aggregated. - * @param agg Indexed by (left/right, node, feature, bin, label) - * where label is the least significant bit. - * The left/right specifier is a 0/1 index indicating left/right child info. - * @param rightChildShift Offset for right side of agg. - */ - def updateBinForUnorderedFeature( - nodeIndex: Int, - featureIndex: Int, - treePoint: TreePoint, - agg: Array[Double], - rightChildShift: Int): Unit = { - val featureValue = treePoint.binnedFeatures(featureIndex) - // Update the left or right count for one bin. - val aggShift = - numClasses * numBins * numFeatures * nodeIndex + - numClasses * numBins * featureIndex + - treePoint.label.toInt - // Find all matching bins and increment their values - val featureCategories = metadata.featureArity(featureIndex) - val numCategoricalBins = (1 << featureCategories - 1) - 1 - var binIndex = 0 - while (binIndex < numCategoricalBins) { - val aggIndex = aggShift + binIndex * numClasses - if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { - agg(aggIndex) += 1 - } else { - agg(rightChildShift + aggIndex) += 1 - } - binIndex += 1 - } - } - - /** - * Helper for binSeqOp. - * - * @param agg Array storing aggregate calculation, of size: - * numClasses * numBins * numFeatures * numNodes. - * Indexed by (node, feature, bin, label) where label is the least significant bit. - * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - */ - def binaryOrNotCategoricalBinSeqOp( - agg: Array[Double], - treePoint: TreePoint, - nodeIndex: Int): Unit = { - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) - featureIndex += 1 - } - } ->>>>>>> upstream/master val rightChildShift = numClasses * numBins * numFeatures * numNodes @@ -933,23 +659,17 @@ object DecisionTree extends Serializable with Logging { * @param treePoint Data point being aggregated. * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ -<<<<<<< HEAD def someUnorderedBinSeqOp( - agg: Array[Array[Array[ImpurityAggregator]]], -======= - def multiclassWithCategoricalBinSeqOp( - agg: Array[Double], ->>>>>>> upstream/master - treePoint: TreePoint, - nodeIndex: Int): Unit = { + agg: Array[Array[Array[ImpurityAggregator]]], + treePoint: TreePoint, + nodeIndex: Int): Unit = { val label = treePoint.label // Iterate over all features. 
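For reference, the flat-array layout that the removed helpers above index into can be restated as one arithmetic helper. A minimal standalone sketch (object name hypothetical), with label as the least significant dimension, mirroring the removed updateBinForOrderedFeature:

object FlatAggIndex {
  // Index into a flat aggregate of length numClasses * numBins * numFeatures * numNodes,
  // ordered as (node, feature, bin, label).
  def apply(
      numClasses: Int, numBins: Int, numFeatures: Int,
      nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Int): Int = {
    numClasses * numBins * numFeatures * nodeIndex +
      numClasses * numBins * featureIndex +
      numClasses * binIndex +
      label
  }
}

The replacement code keeps the same (node, feature, bin) nesting but as Array[Array[Array[ImpurityAggregator]]], so this index arithmetic disappears.
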
var featureIndex = 0 while (featureIndex < numFeatures) { if (metadata.isUnordered(featureIndex)) { -<<<<<<< HEAD // Unordered feature - val featureValue = treePoint.features(featureIndex) + val featureValue = treePoint.binnedFeatures(featureIndex) // Update the left or right count for one bin. // Find all matching bins and increment their values. val numCategoricalBins = metadata.numBins(featureIndex) @@ -964,13 +684,8 @@ object DecisionTree extends Serializable with Logging { } } else { // Ordered feature - val binIndex = treePoint.features(featureIndex) + val binIndex = treePoint.binnedFeatures(featureIndex) agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label) -======= - updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift) - } else { - updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) ->>>>>>> upstream/master } featureIndex += 1 } @@ -989,36 +704,21 @@ object DecisionTree extends Serializable with Logging { * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). * @return agg */ -<<<<<<< HEAD def orderedBinSeqOp( - agg: Array[Array[Array[ImpurityAggregator]]], - treePoint: TreePoint, - nodeIndex: Int): Unit = { -======= - def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = { ->>>>>>> upstream/master + agg: Array[Array[Array[ImpurityAggregator]]], + treePoint: TreePoint, + nodeIndex: Int): Unit = { val label = treePoint.label // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { // Update count, sum, and sum^2 for one bin. -<<<<<<< HEAD - val binIndex = treePoint.features(featureIndex) + val binIndex = treePoint.binnedFeatures(featureIndex) if (binIndex >= agg(nodeIndex)(featureIndex).size) { throw new RuntimeException( s"binIndex: $binIndex, agg(nodeIndex)(featureIndex).size = ${agg(nodeIndex)(featureIndex).size}") } agg(nodeIndex)(featureIndex)(binIndex).add(label) -======= - val binIndex = treePoint.binnedFeatures(featureIndex) - val aggIndex = - 3 * numBins * numFeatures * nodeIndex + - 3 * numBins * featureIndex + - 3 * binIndex - agg(aggIndex) += 1 - agg(aggIndex + 1) += label - agg(aggIndex + 2) += label * label ->>>>>>> upstream/master featureIndex += 1 } } @@ -1037,52 +737,27 @@ object DecisionTree extends Serializable with Logging { * Ordered features: numNodes * numFeatures * numBins. * Unordered features: (2 * numNodes) * numFeatures * numBins. * Size for regression: -<<<<<<< HEAD * numNodes * numFeatures * numBins. * @param treePoint Data point being aggregated. * @return agg */ def binSeqOp( - agg: Array[Array[Array[ImpurityAggregator]]], - treePoint: TreePoint): Array[Array[Array[ImpurityAggregator]]] = { - val nodeIndex = treePointToNodeIndex(treePoint) - if (nodeIndex >= 0) { // Otherwise, example does not reach this level. - if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg, treePoint, nodeIndex) - } else { - someUnorderedBinSeqOp(agg, treePoint, nodeIndex) -======= - * 3 * numBins * numFeatures * numNodes. - * @param treePoint Data point being aggregated. - * @return agg - */ - def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = { + agg: Array[Array[Array[ImpurityAggregator]]], + treePoint: TreePoint): Array[Array[Array[ImpurityAggregator]]] = { val nodeIndex = treePointToNodeIndex(treePoint) // If the example does not reach this level, then nodeIndex < 0. 
// If the example reaches this level but is handled in a different group, // then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group). if (nodeIndex >= 0 && nodeIndex < numNodes) { - if (metadata.isClassification) { - if (isMulticlassWithCategoricalFeatures) { - multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex) - } else { - binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex) - } + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg, treePoint, nodeIndex) } else { - regressionBinSeqOp(agg, treePoint, nodeIndex) ->>>>>>> upstream/master + someUnorderedBinSeqOp(agg, treePoint, nodeIndex) } } agg } -<<<<<<< HEAD -======= - // Calculate bin aggregate length for classification or regression. - val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins) - logDebug("binAggregateLength = " + binAggregateLength) - ->>>>>>> upstream/master /** * Combines the aggregates from partitions. * @param agg1 Array containing aggregates from one or more partitions @@ -1109,16 +784,12 @@ object DecisionTree extends Serializable with Logging { } // Calculate bin aggregates. -<<<<<<< HEAD - timer.start("binAggregates") + timer.start("aggregation") val binAggregates = { val initAgg = getEmptyBinAggregates(metadata, numNodes) input.aggregate(initAgg)(binSeqOp, binCombOp) } - timer.stop("binAggregates") - - logDebug("binAggregates.length = " + binAggregates.length) - + timer.stop("aggregation") /* println("binAggregates:") for (n <- Range(0, binAggregates.size)) { @@ -1131,6 +802,7 @@ object DecisionTree extends Serializable with Logging { */ // Calculate best splits for all nodes at a given level + timer.start("chooseSplits") val bestSplits = new Array[(Split, InformationGainStats)](numNodes) val nodeIndexOffset = DecisionTree.maxNodesInLevel(level) - 1 // Iterating over all nodes at this level @@ -1142,7 +814,6 @@ object DecisionTree extends Serializable with Logging { //logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) logDebug("parent node impurity = " + parentNodeImpurity) - val (bestFeatureIndex, bestSplitIndex, bestGain) = binsToBestSplit(binAggregates(nodeIndex), parentNodeImpurity, level, metadata) bestSplits(nodeIndex) = (splits(bestFeatureIndex)(bestSplitIndex), bestGain) @@ -1164,11 +835,11 @@ object DecisionTree extends Serializable with Logging { * @return information gain and statistics for all splits */ def calculateGainForSplit( - leftNodeAgg: ImpurityAggregator, - rightNodeAgg: ImpurityAggregator, - topImpurity: Double, - level: Int, - metadata: LearningMetadata): InformationGainStats = { + leftNodeAgg: ImpurityAggregator, + rightNodeAgg: ImpurityAggregator, + topImpurity: Double, + level: Int, + metadata: DecisionTreeMetadata): InformationGainStats = { val leftCount = leftNodeAgg.count val rightCount = rightNodeAgg.count @@ -1221,147 +892,21 @@ object DecisionTree extends Serializable with Logging { rightNodeAgg: Array[Array[ImpurityAggregator]], nodeImpurity: Double, level: Int, - metadata: LearningMetadata): Array[Array[InformationGainStats]] = { + metadata: DecisionTreeMetadata): Array[Array[InformationGainStats]] = { val gains = new Array[Array[InformationGainStats]](metadata.numFeatures) - for (featureIndex <- 0 until metadata.numFeatures) { + var featureIndex = 0 + while (featureIndex < metadata.numFeatures) { val numSplitsForFeature = metadata.numSplits(featureIndex) gains(featureIndex) = new Array[InformationGainStats](numSplitsForFeature) - for (splitIndex <- 
0 until numSplitsForFeature) { + var splitIndex = 0 + while (splitIndex < numSplitsForFeature) { gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex), rightNodeAgg(featureIndex)(splitIndex), nodeImpurity, level, metadata) -======= - timer.start("aggregation") - val binAggregates = { - input.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) - } - timer.stop("aggregation") - logDebug("binAggregates.length = " + binAggregates.length) - - /** - * Calculate the information gain for a given (feature, split) based upon left/right aggregates. - * @param leftNodeAgg left node aggregates for this (feature, split) - * @param rightNodeAgg right node aggregate for this (feature, split) - * @param topImpurity impurity of the parent node - * @return information gain and statistics for all splits - */ - def calculateGainForSplit( - leftNodeAgg: Array[Double], - rightNodeAgg: Array[Double], - topImpurity: Double): InformationGainStats = { - if (metadata.isClassification) { - val leftTotalCount = leftNodeAgg.sum - val rightTotalCount = rightNodeAgg.sum - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val rootNodeCounts = new Array[Double](numClasses) - var classIndex = 0 - while (classIndex < numClasses) { - rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex) - classIndex += 1 - } - metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) - } - } - - val totalCount = leftTotalCount + rightTotalCount - if (totalCount == 0) { - // Return arbitrary prediction. - return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) - } - - // Sum of count for each label - val leftrightNodeAgg: Array[Double] = - leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) => - leftCount + rightCount - } - - def indexOfLargestArrayElement(array: Array[Double]): Int = { - val result = array.foldLeft(-1, Double.MinValue, 0) { - case ((maxIndex, maxValue, currentIndex), currentValue) => - if (currentValue > maxValue) { - (currentIndex, currentValue, currentIndex + 1) - } else { - (maxIndex, maxValue, currentIndex + 1) - } - } - if (result._1 < 0) { - throw new RuntimeException("DecisionTree internal error:" + - " calculateGainForSplit failed in indexOfLargestArrayElement") - } - result._1 - } - - val predict = indexOfLargestArrayElement(leftrightNodeAgg) - val prob = leftrightNodeAgg(predict) / totalCount - - val leftImpurity = if (leftTotalCount == 0) { - topImpurity - } else { - metadata.impurity.calculate(leftNodeAgg, leftTotalCount) - } - val rightImpurity = if (rightTotalCount == 0) { - topImpurity - } else { - metadata.impurity.calculate(rightNodeAgg, rightTotalCount) - } - - val leftWeight = leftTotalCount / totalCount - val rightWeight = rightTotalCount / totalCount - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) - - } else { - // Regression - - val leftCount = leftNodeAgg(0) - val leftSum = leftNodeAgg(1) - val leftSumSquares = leftNodeAgg(2) - - val rightCount = rightNodeAgg(0) - val rightSum = rightNodeAgg(1) - val rightSumSquares = rightNodeAgg(2) - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. 
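Both deleted branches (classification above, regression continuing below) end in the same weighted-impurity arithmetic, which the new aggregator-based calculateGainForSplit also uses. A minimal standalone sketch of that formula (helper name hypothetical; the real code returns full InformationGainStats and guards empty nodes separately):

object InfoGainSketch {
  // gain = parent impurity minus the count-weighted impurities of the children.
  def infoGain(
      parentImpurity: Double,
      leftImpurity: Double, leftCount: Double,
      rightImpurity: Double, rightCount: Double): Double = {
    val totalCount = leftCount + rightCount
    if (totalCount == 0) return 0.0 // degenerate split; real code returns arbitrary stats
    val leftWeight = leftCount / totalCount
    val rightWeight = rightCount / totalCount
    parentImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
  }
}
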
- val count = leftCount + rightCount - val sum = leftSum + rightSum - val sumSquares = leftSumSquares + rightSumSquares - metadata.impurity.calculate(count, sum, sumSquares) - } - } - - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum / rightCount) - } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, - Double.MinValue, leftSum / leftCount) - } - - val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares) - val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares) - - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - val predict = (leftSum + rightSum) / (leftCount + rightCount) - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) ->>>>>>> upstream/master + splitIndex += 1 } + featureIndex += 1 } gains } @@ -1381,8 +926,8 @@ object DecisionTree extends Serializable with Logging { * TODO: Extract in-place. */ def extractLeftRightNodeAggregates( - nodeAggregates: Array[Array[ImpurityAggregator]], - metadata: LearningMetadata): (Array[Array[ImpurityAggregator]], Array[Array[ImpurityAggregator]]) = { + nodeAggregates: Array[Array[ImpurityAggregator]], + metadata: DecisionTreeMetadata): (Array[Array[ImpurityAggregator]], Array[Array[ImpurityAggregator]]) = { val numClasses = metadata.numClasses val numFeatures = metadata.numFeatures @@ -1392,61 +937,24 @@ object DecisionTree extends Serializable with Logging { * Indexes binData as (feature, split, class) with class as the least significant bit. * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value */ -<<<<<<< HEAD def findAggForUnorderedFeature( - binData: Array[Array[ImpurityAggregator]], - leftNodeAgg: Array[Array[ImpurityAggregator]], - rightNodeAgg: Array[Array[ImpurityAggregator]], - featureIndex: Int) { + binData: Array[Array[ImpurityAggregator]], + leftNodeAgg: Array[Array[ImpurityAggregator]], + rightNodeAgg: Array[Array[ImpurityAggregator]], + featureIndex: Int) { // TODO: Don't pass in featureIndex; use index before call. // Note: numBins = numSplits for unordered features. val numBins = metadata.numBins(featureIndex) leftNodeAgg(featureIndex) = binData(featureIndex).slice(0, numBins) rightNodeAgg(featureIndex) = binData(featureIndex).slice(numBins, 2 * numBins) } -======= - def extractLeftRightNodeAggregates( - binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { - - - /** - * The input binData is indexed as (feature, bin, class). - * This computes cumulative sums over splits. - * Each (feature, class) pair is handled separately. - * Note: numSplits = numBins - 1. - * @param leftNodeAgg Each (feature, class) slice is an array over splits. - * Element i (i = 0, ..., numSplits - 2) is set to be - * the cumulative sum (from left) over binData for bins 0, ..., i. - * @param rightNodeAgg Each (feature, class) slice is an array over splits. - * Element i (i = 1, ..., numSplits - 1) is set to be - * the cumulative sum (from right) over binData for bins - * numBins - 1, ..., numBins - 1 - i. 
- */ - def findAggForOrderedFeatureClassification( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { - - // shift for this featureIndex - val shift = numClasses * featureIndex * numBins - - var classIndex = 0 - while (classIndex < numClasses) { - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex) - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(classIndex) - = binData(shift + (numClasses * (numBins - 1)) + classIndex) - classIndex += 1 - } ->>>>>>> upstream/master /** * For ordered features (regression and classification with ordered features). * The input binData is indexed as (feature, bin, class). * This computes cumulative sums over splits. * Each (feature, class) pair is handled separately. - * Note: numSplits = numBins - 1. + * TODO: UPDATE DOC: Note: numSplits = numBins - 1. * @param leftNodeAgg Each (feature, class) slice is an array over splits. * Element i (i = 0, ..., numSplits - 2) is set to be * the cumulative sum (from left) over binData for bins 0, ..., i. @@ -1457,13 +965,12 @@ object DecisionTree extends Serializable with Logging { * TODO: We could avoid doing one of these cumulative sums. */ def findAggForOrderedFeature( - binData: Array[Array[ImpurityAggregator]], - leftNodeAgg: Array[Array[ImpurityAggregator]], - rightNodeAgg: Array[Array[ImpurityAggregator]], - featureIndex: Int) { + binData: Array[Array[ImpurityAggregator]], + leftNodeAgg: Array[Array[ImpurityAggregator]], + rightNodeAgg: Array[Array[ImpurityAggregator]], + featureIndex: Int) { // TODO: Don't pass in featureIndex; use index before call. - val numSplits = metadata.numSplits(featureIndex) leftNodeAgg(featureIndex) = new Array[ImpurityAggregator](numSplits) rightNodeAgg(featureIndex) = new Array[ImpurityAggregator](numSplits) @@ -1488,20 +995,19 @@ object DecisionTree extends Serializable with Logging { splitIndex += 1 } } else { // ordered categorical feature - /* TODO: This is a temp fix. - * Eventually, for ordered categorical features, change splits and bins to be - * for individual categories instead of running totals over a pre-defined category - * ordering. Then, we could choose the ordering in this function, tailoring it - * to this particular node. - */ - var splitIndex = 0 + /* TODO: This is a temp fix. + * Eventually, for ordered categorical features, change splits and bins to be + * for individual categories instead of running totals over a pre-defined category + * ordering. Then, we could choose the ordering in this function, tailoring it + * to this particular node. + */ + var splitIndex = 0 while (splitIndex < numSplits) { // no need to clone since no cumulative sum is needed leftNodeAgg(featureIndex)(splitIndex) = binData(featureIndex)(splitIndex) rightNodeAgg(featureIndex)(splitIndex) = binData(featureIndex)(splitIndex + 1) splitIndex += 1 } -<<<<<<< HEAD } } @@ -1518,69 +1024,15 @@ object DecisionTree extends Serializable with Logging { featureIndex += 1 } } else { // Regression - var featureIndex = 0 + var featureIndex = 0 while (featureIndex < numFeatures) { findAggForOrderedFeature(nodeAggregates, leftNodeAgg, rightNodeAgg, featureIndex) featureIndex += 1 -======= - } - - if (metadata.isClassification) { - // Initialize left and right split aggregates. 
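To make the cumulative sums described above concrete, here is a minimal numeric sketch with plain Doubles standing in for ImpurityAggregators, for the ordered case where numSplits = numBins - 1 (object name hypothetical):

object CumulativeAggSketch {
  // binCounts(b): examples falling in bin b for one (node, feature).
  // Split s sends bins 0..s left and bins s+1..numBins-1 right, so left
  // aggregates are prefix sums and right aggregates are suffix sums.
  def leftRightCounts(binCounts: Array[Double]): (Array[Double], Array[Double]) = {
    require(binCounts.length >= 2, "need at least two bins")
    val numSplits = binCounts.length - 1
    val left = new Array[Double](numSplits)
    val right = new Array[Double](numSplits)
    left(0) = binCounts(0)
    var s = 1
    while (s < numSplits) {
      left(s) = left(s - 1) + binCounts(s)
      s += 1
    }
    right(numSplits - 1) = binCounts(numSplits)
    s = numSplits - 2
    while (s >= 0) {
      right(s) = right(s + 1) + binCounts(s + 1)
      s -= 1
    }
    (left, right)
  }
}
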
- val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (metadata.isUnordered(featureIndex)) { - findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - featureIndex += 1 - } - (leftNodeAgg, rightNodeAgg) - } else { - // Regression - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) - featureIndex += 1 - } - (leftNodeAgg, rightNodeAgg) - } - } - - /** - * Calculates information gain for all nodes splits. - */ - def calculateGainsForAllNodeSplits( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - nodeImpurity: Double): Array[Array[InformationGainStats]] = { - val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) - - var featureIndex = 0 - while (featureIndex < numFeatures) { - val numSplitsForFeature = getNumSplitsForFeature(featureIndex) - var splitIndex = 0 - while (splitIndex < numSplitsForFeature) { - gains(featureIndex)(splitIndex) = - calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex), - rightNodeAgg(featureIndex)(splitIndex), nodeImpurity) - splitIndex += 1 - } - featureIndex += 1 ->>>>>>> upstream/master } } (leftNodeAgg, rightNodeAgg) } -<<<<<<< HEAD /** * Find the best split for a node. * @param binData Bin data slice for this node, given by getBinDataForNode. @@ -1588,10 +1040,10 @@ object DecisionTree extends Serializable with Logging { * @return tuple (best feature index, best split index, information gain) */ def binsToBestSplit( - nodeAggregates: Array[Array[ImpurityAggregator]], - nodeImpurity: Double, - level: Int, - metadata: LearningMetadata): (Int, Int, InformationGainStats) = { + nodeAggregates: Array[Array[ImpurityAggregator]], + nodeImpurity: Double, + level: Int, + metadata: DecisionTreeMetadata): (Int, Int, InformationGainStats) = { logDebug("node impurity = " + nodeImpurity) /* @@ -1599,70 +1051,10 @@ object DecisionTree extends Serializable with Logging { for (f <- Range(0, nodeAggregates.size)) { for (b <- Range(0, nodeAggregates(f).size)) { println(s"nodeAggregates($f)($b): ${nodeAggregates(f)(b)}") -======= - /** - * Get the number of splits for a feature. - */ - def getNumSplitsForFeature(featureIndex: Int): Int = { - if (metadata.isContinuous(featureIndex)) { - numBins - 1 - } else { - // Categorical feature - val featureCategories = metadata.featureArity(featureIndex) - if (metadata.isUnordered(featureIndex)) { - (1 << featureCategories - 1) - 1 - } else { - featureCategories - } } } - - /** - * Find the best split for a node. - * @param binData Bin data slice for this node, given by getBinDataForNode. - * @param nodeImpurity impurity of the top node - * @return tuple of split and information gain - */ - def binsToBestSplit( - binData: Array[Double], - nodeImpurity: Double): (Split, InformationGainStats) = { - - logDebug("node impurity = " + nodeImpurity) - - // Extract left right node aggregates. - val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) - - // Calculate gains for all splits. 
- val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - - val (bestFeatureIndex, bestSplitIndex, gainStats) = { - // Initialize with infeasible values. - var bestFeatureIndex = Int.MinValue - var bestSplitIndex = Int.MinValue - var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) - // Iterate over features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Iterate over all splits. - var splitIndex = 0 - val numSplitsForFeature = getNumSplitsForFeature(featureIndex) - while (splitIndex < numSplitsForFeature) { - val gainStats = gains(featureIndex)(splitIndex) - if (gainStats.gain > bestGainStats.gain) { - bestGainStats = gainStats - bestFeatureIndex = featureIndex - bestSplitIndex = splitIndex - } - splitIndex += 1 - } - featureIndex += 1 - } - (bestFeatureIndex, bestSplitIndex, bestGainStats) ->>>>>>> upstream/master - } - } -<<<<<<< HEAD */ + // Extract left right node aggregates. val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(nodeAggregates, metadata) @@ -1692,61 +1084,10 @@ object DecisionTree extends Serializable with Logging { splitIndex += 1 } featureIndex += 1 -======= - - /** - * Get bin data for one node. - */ - def getBinDataForNode(node: Int): Array[Double] = { - if (metadata.isClassification) { - if (isMulticlassWithCategoricalFeatures) { - val shift = numClasses * node * numBins * numFeatures - val rightChildShift = numClasses * numBins * numFeatures * numNodes - val binsForNode = { - val leftChildData - = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - val rightChildData - = binAggregates.slice(rightChildShift + shift, - rightChildShift + shift + numClasses * numBins * numFeatures) - leftChildData ++ rightChildData - } - binsForNode - } else { - val shift = numClasses * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - binsForNode - } - } else { - // Regression - val shift = 3 * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) - binsForNode ->>>>>>> upstream/master } (bestFeatureIndex, bestSplitIndex, bestGainStats) } - -<<<<<<< HEAD (bestFeatureIndex, bestSplitIndex, gainStats) -======= - // Calculate best splits for all nodes at a given level - timer.start("chooseSplits") - val bestSplits = new Array[(Split, InformationGainStats)](numNodes) - // Iterating over all nodes at this level - var node = 0 - while (node < numNodes) { - val nodeImpurityIndex = (1 << level) - 1 + node + groupShift - val binsForNode: Array[Double] = getBinDataForNode(node) - logDebug("nodeImpurityIndex = " + nodeImpurityIndex) - val parentNodeImpurity = parentImpurities(nodeImpurityIndex) - logDebug("parent node impurity = " + parentNodeImpurity) - bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) - node += 1 - } - timer.stop("chooseSplits") - - bestSplits ->>>>>>> upstream/master } /** @@ -1773,7 +1114,7 @@ object DecisionTree extends Serializable with Logging { * where the bins are ordered as (numBins left bins, numBins right bins). */ private def getEmptyBinAggregates( - metadata: LearningMetadata, + metadata: DecisionTreeMetadata, numNodes: Int): Array[Array[Array[ImpurityAggregator]]] = { val impurityAggregator: ImpurityAggregator = metadata.impurity match { case Gini => new GiniAggregator(metadata.numClasses) @@ -1825,12 +1166,7 @@ object DecisionTree extends Serializable with Logging { * there is one bin per category. 
* * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] -<<<<<<< HEAD - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree -======= * @param metadata Learning and dataset metadata ->>>>>>> upstream/master * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] * of size (numFeatures, numSplits). @@ -1839,45 +1175,13 @@ object DecisionTree extends Serializable with Logging { */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], -<<<<<<< HEAD - metadata: LearningMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { - - val isMulticlassClassification = metadata.isMulticlass - logDebug("isMulticlass = " + isMulticlassClassification) - - val numFeatures = metadata.numFeatures - -======= metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { - val count = input.count() - - // Find the number of features by looking at the first sample - val numFeatures = input.take(1)(0).features.size - - val maxBins = metadata.maxBins - val numBins = if (maxBins <= count) maxBins else count.toInt - logDebug("numBins = " + numBins) val isMulticlass = metadata.isMulticlass logDebug("isMulticlass = " + isMulticlass) - /* - * Ensure numBins is always greater than the categories. For multiclass classification, - * numBins should be greater than 2^(maxCategories - 1) - 1. - * It's a limitation of the current implementation but a reasonable trade-off since features - * with large number of categories get favored over continuous features. - * - * This needs to be checked here instead of in Strategy since numBins can be determined - * by the number of training examples. - * TODO: Allow this case, where we simply will know nothing about some categories. - */ - if (metadata.featureArity.size > 0) { - val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2 - require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + - "in categorical features") - } + val numFeatures = metadata.numFeatures ->>>>>>> upstream/master // Calculate the number of sample for approximate quantile calculation. val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) val fraction = if (requiredSamples < metadata.numExamples) { @@ -1892,12 +1196,6 @@ object DecisionTree extends Serializable with Logging { input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() val numSamples = sampledInput.length -<<<<<<< HEAD -======= - val stride: Double = numSamples.toDouble / numBins - logDebug("stride = " + stride) - ->>>>>>> upstream/master metadata.quantileStrategy match { case Sort => val splits = new Array[Array[Split]](numFeatures) @@ -1914,14 +1212,8 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { -<<<<<<< HEAD val numSplits = metadata.numSplits(featureIndex) if (metadata.isContinuous(featureIndex)) { -======= - // Check whether the feature is continuous. 
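The sample-fraction choice above (requiredSamples = max(maxBins * maxBins, 10000)) is easy to restate as a standalone helper; a sketch under those same constants (object name hypothetical):

object SampleFractionSketch {
  // Aim for roughly maxBins * maxBins samples (at least 10000) for the
  // approximate quantile computation, capped at the full dataset.
  def sampleFraction(maxBins: Int, numExamples: Long): Double = {
    val requiredSamples = math.max(maxBins * maxBins, 10000)
    if (requiredSamples < numExamples) requiredSamples.toDouble / numExamples else 1.0
  }
}

For example, maxBins = 100 and numExamples = 1000000 gives a fraction of 0.01.
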
- val isFeatureContinuous = metadata.isContinuous(featureIndex) - if (isFeatureContinuous) { ->>>>>>> upstream/master val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex) logDebug("stride = " + stride) @@ -1932,7 +1224,6 @@ object DecisionTree extends Serializable with Logging { splits(featureIndex)(splitIndex) = new Split(featureIndex, threshold, Continuous, List()) } -<<<<<<< HEAD bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) for (splitIndex <- 1 until numSplits) { @@ -1956,23 +1247,6 @@ object DecisionTree extends Serializable with Logging { new Split(featureIndex, Double.MinValue, Categorical, categories) bins(featureIndex)(splitIndex) = { if (splitIndex == 0) { -======= - } else { // Categorical feature - val featureCategories = metadata.featureArity(featureIndex) - - // Use different bin/split calculation strategy for categorical features in multiclass - // classification that satisfy the space constraint. - if (metadata.isUnordered(featureIndex)) { - // 2^(maxFeatureValue- 1) - 1 combinations - var index = 0 - while (index < (1 << featureCategories - 1) - 1) { - val categories: List[Double] - = extractMultiClassCategories(index + 1, featureCategories) - splits(featureIndex)(index) - = new Split(featureIndex, Double.MinValue, Categorical, categories) - bins(featureIndex)(index) = { - if (index == 0) { ->>>>>>> upstream/master new Bin( new DummyCategoricalSplit(featureIndex, Categorical), splits(featureIndex)(0), @@ -2059,27 +1333,6 @@ object DecisionTree extends Serializable with Logging { } featureIndex += 1 } - -<<<<<<< HEAD -======= - // Find all bins. - featureIndex = 0 - while (featureIndex < numFeatures) { - val isFeatureContinuous = metadata.isContinuous(featureIndex) - if (isFeatureContinuous) { // Bins for categorical variables are already assigned. - bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), - splits(featureIndex)(0), Continuous, Double.MinValue) - for (index <- 1 until numBins - 1) { - val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), - Continuous, Double.MinValue) - bins(featureIndex)(index) = bin - } - bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2), - new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) - } - featureIndex += 1 - } ->>>>>>> upstream/master (splits, bins) case MinMax => throw new UnsupportedOperationException("minmax not supported yet.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index fcfee626d03e2..40a6955e88f76 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD /* -TODO: MERGE DOC: * TODO: Add doc about ordered vs. unordered features. * Ensure numBins is always greater than the categories. For multiclass classification, * numBins should be greater than math.pow(2, maxCategories - 1) - 1. @@ -38,9 +37,9 @@ TODO: MERGE DOC: * * This needs to be checked here instead of in Strategy since numBins can be determined * by the number of training examples. - * TODO: Allow this case, where we simply will know nothing about some categories. 
*/ + /** * Learning and dataset metadata for DecisionTree. * @@ -81,7 +80,6 @@ private[tree] class DecisionTreeMetadata( } - private[tree] object DecisionTreeMetadata { def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala index e7074b4a074fc..d215d68c4279e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala @@ -25,12 +25,7 @@ import org.apache.spark.annotation.Experimental * Time tracker implementation which holds labeled timers. */ @Experimental -<<<<<<< HEAD -private[tree] -class TimeTracker extends Serializable { -======= private[tree] class TimeTracker extends Serializable { ->>>>>>> upstream/master private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() @@ -40,40 +35,24 @@ private[tree] class TimeTracker extends Serializable { * Starts a new timer, or re-starts a stopped timer. */ def start(timerLabel: String): Unit = { -<<<<<<< HEAD - val tmpTime = System.nanoTime() -======= val currentTime = System.nanoTime() ->>>>>>> upstream/master if (starts.contains(timerLabel)) { throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + s" timerLabel = $timerLabel before that timer was stopped.") } -<<<<<<< HEAD - starts(timerLabel) = tmpTime -======= starts(timerLabel) = currentTime ->>>>>>> upstream/master } /** * Stops a timer and returns the elapsed time in seconds. */ def stop(timerLabel: String): Double = { -<<<<<<< HEAD - val tmpTime = System.nanoTime() -======= val currentTime = System.nanoTime() ->>>>>>> upstream/master if (!starts.contains(timerLabel)) { throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + s" timerLabel = $timerLabel, but that timer was not started.") } -<<<<<<< HEAD - val elapsed = tmpTime - starts(timerLabel) -======= val elapsed = currentTime - starts(timerLabel) ->>>>>>> upstream/master starts.remove(timerLabel) if (totals.contains(timerLabel)) { totals(timerLabel) += elapsed diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index b3ecd8022a26a..54b099fb24b59 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -18,10 +18,6 @@ package org.apache.spark.mllib.tree.impl import org.apache.spark.mllib.regression.LabeledPoint -<<<<<<< HEAD -import org.apache.spark.mllib.tree.LearningMetadata -======= ->>>>>>> upstream/master import org.apache.spark.mllib.tree.model.Bin import org.apache.spark.rdd.RDD @@ -38,15 +34,6 @@ import org.apache.spark.rdd.RDD * or any categorical feature used in regression or binary classification. * * @param label Label from LabeledPoint -<<<<<<< HEAD - * @param features Binned feature values. - * Same length as LabeledPoint.features, but values are bin indices. - */ -private[tree] class TreePoint(val label: Double, val features: Array[Int]) extends Serializable { -} - - -======= * @param binnedFeatures Binned feature values. * Same length as LabeledPoint.features, but values are bin indices. 
*/ @@ -54,7 +41,6 @@ private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) extends Serializable { } ->>>>>>> upstream/master private[tree] object TreePoint { /** @@ -62,21 +48,13 @@ private[tree] object TreePoint { * binning feature values in preparation for DecisionTree training. * @param input Input dataset. * @param bins Bins for features, of size (numFeatures, numBins). -<<<<<<< HEAD - * @param metadata DecisionTree training info, used for dataset metadata. -======= - * @param metadata Learning and dataset metadata ->>>>>>> upstream/master + * @param metadata Learning and dataset metadata * @return TreePoint dataset representation */ def convertToTreeRDD( input: RDD[LabeledPoint], bins: Array[Array[Bin]], -<<<<<<< HEAD - metadata: LearningMetadata): RDD[TreePoint] = { -======= metadata: DecisionTreeMetadata): RDD[TreePoint] = { ->>>>>>> upstream/master input.map { x => TreePoint.labeledPointToTreePoint(x, bins, metadata) } @@ -85,24 +63,13 @@ private[tree] object TreePoint { /** * Convert one LabeledPoint into its TreePoint representation. * @param bins Bins for features, of size (numFeatures, numBins). -<<<<<<< HEAD * @param metadata DecisionTree training info, used for dataset metadata. -======= ->>>>>>> upstream/master */ private def labeledPointToTreePoint( labeledPoint: LabeledPoint, bins: Array[Array[Bin]], -<<<<<<< HEAD - metadata: LearningMetadata): TreePoint = { - - val numFeatures = labeledPoint.features.size -======= metadata: DecisionTreeMetadata): TreePoint = { - val numFeatures = labeledPoint.features.size - val numBins = bins(0).size ->>>>>>> upstream/master val arr = new Array[Int](numFeatures) var featureIndex = 0 while (featureIndex < numFeatures) { @@ -114,10 +81,6 @@ private[tree] object TreePoint { new TreePoint(labeledPoint.label, arr) } -<<<<<<< HEAD - -======= ->>>>>>> upstream/master /** * Find bin for one (labeledPoint, feature). * @@ -148,17 +111,9 @@ private[tree] object TreePoint { val highThreshold = bin.highSplit.threshold if ((lowThreshold < feature) && (highThreshold >= feature)) { return mid -<<<<<<< HEAD - } - else if (lowThreshold >= feature) { - right = mid - 1 - } - else { -======= } else if (lowThreshold >= feature) { right = mid - 1 } else { ->>>>>>> upstream/master left = mid + 1 } } @@ -206,12 +161,8 @@ private[tree] object TreePoint { // Perform binary search for finding bin for continuous features. val binIndex = binarySearchForBins() if (binIndex == -1) { -<<<<<<< HEAD - throw new UnknownError("No bin was found for continuous feature." + -======= throw new RuntimeException("No bin was found for continuous feature." + " This error can occur when given invalid data values (such as NaN)." + ->>>>>>> upstream/master s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") } binIndex @@ -223,15 +174,10 @@ private[tree] object TreePoint { sequentialBinSearchForOrderedCategoricalFeature() } if (binIndex == -1) { -<<<<<<< HEAD - throw new UnknownError("No bin was found for categorical feature." + - s" Feature index: $featureIndex. isUnorderedFeature = $isUnorderedFeature." + - s" Feature value: ${labeledPoint.features(featureIndex)}") -======= throw new RuntimeException("No bin was found for categorical feature." + " This error can occur when given invalid data values (such as NaN)." + - s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") ->>>>>>> upstream/master + s" Feature index: $featureIndex. isUnorderedFeature = $isUnorderedFeature." 
+ + s" Feature value: ${labeledPoint.features(featureIndex)}") } binIndex } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 50eca20a07731..fadeef15b07d2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -17,27 +17,20 @@ package org.apache.spark.mllib.tree -import org.apache.spark.mllib.tree.impl.TreePoint - import scala.collection.JavaConverters._ import org.scalatest.FunSuite -<<<<<<< HEAD -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} -import org.apache.spark.mllib.tree.configuration.Strategy -======= ->>>>>>> upstream/master +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext -import org.apache.spark.mllib.regression.LabeledPoint + class DecisionTreeSuite extends FunSuite with LocalSparkContext { @@ -50,15 +43,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { prediction != expected.label } val accuracy = (input.length - numOffPredictions).toDouble / input.length -<<<<<<< HEAD - if (accuracy < requiredAccuracy) { - println(s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") - } - assert(accuracy >= requiredAccuracy) -======= assert(accuracy >= requiredAccuracy, s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") ->>>>>>> upstream/master } def validateRegressor( @@ -79,11 +65,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) -<<<<<<< HEAD - val metadata = LearningMetadata.buildMetadata(rdd, strategy) -======= val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) ->>>>>>> upstream/master val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) @@ -102,16 +84,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) -<<<<<<< HEAD - val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) -======= - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) ->>>>>>> upstream/master assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 1) @@ -183,9 +160,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) 
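The assertions that follow check the ordered-categorical layout: for binary classification, a categorical feature with k categories keeps k bins and k - 1 splits, where split i covers a growing prefix of a pre-defined category ordering (the running totals noted in the temp-fix comment earlier). A trivial sketch of the expected counts (name hypothetical):

object OrderedCategoricalCounts {
  // (numSplits, numBins) for an ordered categorical feature with arity k.
  def apply(k: Int): (Int, Int) = (k - 1, k)
}

Here k = 3, so OrderedCategoricalCounts(3) = (2, 3), matching splits(0).length === 2 and bins(0).length === 3 below.
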
-<<<<<<< HEAD - val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -193,10 +169,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(splits(0).length === 2) assert(bins(0).length === 3) -======= - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) ->>>>>>> upstream/master // Check splits. @@ -300,9 +272,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) -<<<<<<< HEAD - val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -310,10 +281,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(splits(0).length === 3) assert(bins(0).length === 3) -======= - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) ->>>>>>> upstream/master // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -400,9 +367,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) -<<<<<<< HEAD - val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -410,10 +376,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(splits(0).length === 9) assert(bins(0).length === 10) -======= - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) ->>>>>>> upstream/master // 2^10 - 1 > 100, so categorical variables will be ordered @@ -462,9 +424,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) -<<<<<<< HEAD - val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) @@ -473,10 +434,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 2) assert(bins(0).length === 3) -======= - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) ->>>>>>> upstream/master val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, new Array[Node](0), splits, bins, 10) @@ -503,14 +460,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = 
Map(0 -> 3, 1-> 3)) -<<<<<<< HEAD - val metadata = LearningMetadata.buildMetadata(rdd, strategy) + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) -======= - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) ->>>>>>> upstream/master val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, @@ -538,7 +492,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) - val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -552,17 +506,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) -<<<<<<< HEAD val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClassesForClassification = 2, maxBins = 100) - val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) -======= - val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) ->>>>>>> upstream/master val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) @@ -585,17 +534,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) -<<<<<<< HEAD val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClassesForClassification = 2, maxBins = 100) - val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) -======= - val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) ->>>>>>> upstream/master val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) @@ -619,17 +563,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) -<<<<<<< HEAD val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClassesForClassification = 2, maxBins = 100) - val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) -======= - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) ->>>>>>> upstream/master val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) @@ -653,17 +592,12 @@ 
class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) -<<<<<<< HEAD val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClassesForClassification = 2, maxBins = 100) - val metadata = LearningMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) -======= - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) ->>>>>>> upstream/master val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) @@ -683,17 +617,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.predict === 1) } - // TODO: Decide about testing 2nd level - /* test("second level node building with/without groups") { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) -<<<<<<< HEAD -======= val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) ->>>>>>> upstream/master val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) @@ -714,11 +643,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // Single group second level tree construction. val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) -<<<<<<< HEAD - val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters, -======= val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes, ->>>>>>> upstream/master splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) @@ -726,13 +651,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. 
-<<<<<<< HEAD - val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, - filters, splits, bins, 0) -======= val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes, splits, bins, 0) ->>>>>>> upstream/master assert(bestSplitsWithGroups.length === 2) assert(bestSplitsWithGroups(0)._2.gain > 0) assert(bestSplitsWithGroups(1)._2.gain > 0) @@ -748,7 +668,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) } } - */ test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() @@ -757,13 +676,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(strategy.isMulticlassClassification) -<<<<<<< HEAD - val metadata = LearningMetadata.buildMetadata(rdd, strategy) assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) -======= ->>>>>>> upstream/master val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, @@ -788,10 +703,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2) val model = DecisionTree.train(rdd, strategy) -<<<<<<< HEAD - println(model) -======= ->>>>>>> upstream/master validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) @@ -823,11 +734,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, maxBins = maxBins, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) -<<<<<<< HEAD - val metadata = LearningMetadata.buildMetadata(rdd, strategy) -======= val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) ->>>>>>> upstream/master val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) @@ -856,11 +763,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) -<<<<<<< HEAD - val metadata = LearningMetadata.buildMetadata(rdd, strategy) -======= val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) ->>>>>>> upstream/master val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) @@ -886,11 +789,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) -<<<<<<< HEAD - val metadata = LearningMetadata.buildMetadata(rdd, strategy) -======= val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) ->>>>>>> upstream/master val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) @@ -915,11 +814,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) 
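For context on why 10-category features end up ordered even in a multiclass setting: unordered treatment requires all 2^(k-1) - 1 subset-based splits to fit within maxBins. A condensed sketch (assumption: simplified from the real check in DecisionTreeMetadata.buildMetadata):

object UnorderedCheck {
  // Unordered handling only applies to multiclass problems where the
  // (1 << (k - 1)) - 1 subset splits fit within maxBins.
  def isUnordered(isMulticlass: Boolean, k: Int, maxBins: Int): Boolean =
    isMulticlass && (1 << (k - 1)) - 1 <= maxBins
}

With k = 3 and maxBins = 100 this gives 3 <= 100 (unordered, the 2^2 - 1 = 3 splits seen earlier); with k = 10 it gives 511 > 100, so the feature falls back to the ordered, k-bin representation.
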
assert(strategy.isMulticlassClassification) -<<<<<<< HEAD - val metadata = LearningMetadata.buildMetadata(rdd, strategy) -======= val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) ->>>>>>> upstream/master val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) From 61c45093a9ae73b03e7a6737424101e45c5aa123 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 18 Aug 2014 14:02:33 -0700 Subject: [PATCH 18/34] Fixed bugs from merge: missing DT timer call, and numBins setting. Cleaned up DT Suite some. --- .../spark/mllib/tree/DecisionTree.scala | 1 + .../tree/impl/DecisionTreeMetadata.scala | 1 + .../spark/mllib/tree/DecisionTreeSuite.scala | 82 +++++-------------- 3 files changed, 24 insertions(+), 60 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 0a3cfcfe4a104..3e316bea38fdf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -823,6 +823,7 @@ object DecisionTree extends Serializable with Logging { nodeIndex += 1 } + timer.stop("chooseSplits") bestSplits } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 40a6955e88f76..fc5e8a8b6123e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -112,6 +112,7 @@ private[tree] object DecisionTreeMetadata { require(k < maxPossibleBins, s"maxBins (= $maxPossibleBins) should be greater than max categories " + s"in categorical features (>= $k)") + numBins(f) = k } } } else { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index fadeef15b07d2..e911dff8db830 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -60,12 +60,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") } - test("split and bin calculation for continuous features") { + test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) @@ -73,7 +74,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) } - test("split and bin calculation for binary features") { + test("Binary classification with binary features: split and bin calculation") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -100,32 +101,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0)(0).threshold === Double.MinValue) 
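    // (Editor's note: categorical splits never use the threshold, which is why it stays
    // at Double.MinValue; prediction sends an example left iff split.categories contains
    // its feature value, per predictNodeIndex.)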
assert(splits(0)(0).featureType === Categorical) assert(splits(0)(0).categories.length === 1) - //println(s"splits(0)(0).categories: ${splits(0)(0).categories}") assert(splits(0)(0).categories.contains(1.0)) - /* - assert(splits(0)(1).feature === 0) - assert(splits(0)(1).threshold === Double.MinValue) - assert(splits(0)(1).featureType === Categorical) - assert(splits(0)(1).categories.length === 2) - assert(splits(0)(1).categories.contains(0.0)) - assert(splits(0)(1).categories.contains(1.0)) - */ - assert(splits(1)(0).feature === 1) assert(splits(1)(0).threshold === Double.MinValue) assert(splits(1)(0).featureType === Categorical) assert(splits(1)(0).categories.length === 1) assert(splits(1)(0).categories.contains(0.0)) - /* - assert(splits(1)(1).feature === 1) - assert(splits(1)(1).threshold === Double.MinValue) - assert(splits(1)(1).featureType === Categorical) - assert(splits(1)(1).categories.length === 2) - assert(splits(1)(1).categories.contains(0.0)) - assert(splits(1)(1).categories.contains(1.0)) - */ // Check bins. assert(bins(0)(0).lowSplit.categories.length === 0) @@ -185,16 +168,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0)(1).categories.contains(0.0)) assert(splits(0)(1).categories.contains(1.0)) - /* - assert(splits(0)(2).feature === 0) - assert(splits(0)(2).threshold === Double.MinValue) - assert(splits(0)(2).featureType === Categorical) - assert(splits(0)(2).categories.length === 3) - assert(splits(0)(2).categories.contains(0.0)) - assert(splits(0)(2).categories.contains(1.0)) - assert(splits(0)(2).categories.contains(2.0)) - */ - assert(splits(1)(0).feature === 1) assert(splits(1)(0).threshold === Double.MinValue) assert(splits(1)(0).featureType === Categorical) @@ -208,16 +181,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(1)(1).categories.contains(0.0)) assert(splits(1)(1).categories.contains(1.0)) - /* - assert(splits(1)(2).feature === 1) - assert(splits(1)(2).threshold === Double.MinValue) - assert(splits(1)(2).featureType === Categorical) - assert(splits(1)(2).categories.length === 3) - assert(splits(1)(2).categories.contains(0.0)) - assert(splits(1)(2).categories.contains(1.0)) - assert(splits(1)(2).categories.contains(2.0)) - */ - // Check bins. 
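    // (Editor's note: each bin is delimited by a lowSplit and a highSplit; for ordered
    // categorical features, the bins' category lists checked below mirror the cumulative
    // category sets of the neighboring splits.)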
assert(bins(0)(0).lowSplit.categories.length === 0) @@ -260,8 +223,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) } - test("split and bin calculations for unordered categorical variables with multiclass " + - "classification") { + test("Multiclass classification with unordered categorical features:" + + " split and bin calculations") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -355,8 +318,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } - test("split and bin calculations for ordered categorical variables with multiclass " + - "classification") { + test("Multiclass classification with ordered categorical features: split and bin calculations") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() assert(arr.length === 3000) val rdd = sc.parallelize(arr) @@ -377,7 +339,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 9) assert(bins(0).length === 10) - // 2^10 - 1 > 100, so categorical variables will be ordered + // 2^10 - 1 > 100, so categorical features will be ordered assert(splits(0)(0).feature === 0) assert(splits(0)(0).threshold === Double.MinValue) @@ -413,7 +375,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } - test("classification stump with all ordered categorical variables") { + test("Binary classification stump with all ordered categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -450,7 +412,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.impurity > 0.2) } - test("regression stump with all categorical variables") { + test("Regression stump with 3-ary categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -482,7 +444,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.impurity > 0.2) } - test("regression stump with categorical variables of arity 2") { + test("Regression stump with binary categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -502,7 +464,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(model.depth === 1) } - test("stump with fixed label 0 for Gini") { + test("Binary classification stump with fixed label 0 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -530,7 +492,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.rightImpurity === 0) } - test("stump with fixed label 1 for Gini") { + test("Binary classification stump with fixed label 1 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -559,7 +521,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.predict === 1) } - test("stump with fixed label 0 for Entropy") { + test("Binary classification stump with fixed label 0 for Entropy") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -588,7 +550,7 @@ class DecisionTreeSuite 
extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.predict === 0) } - test("stump with fixed label 1 for Entropy") { + test("Binary classification stump with fixed label 1 for Entropy") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -617,7 +579,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.predict === 1) } - test("second level node building with/without groups") { + test("Second level node building with vs. without groups") { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -669,7 +631,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } } - test("stump with categorical variables for multiclass classification") { + test("Multiclass classification stump with 3-ary categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, @@ -692,7 +654,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } - test("stump with 1 continuous variable for binary classification, to check off-by-1 error") { + test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { val arr = new Array[LabeledPoint](4) arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) @@ -708,7 +670,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(model.depth === 1) } - test("stump with 2 continuous variables for binary classification") { + test("Binary classification stump with 2 continuous features") { val arr = new Array[LabeledPoint](4) arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) @@ -726,7 +688,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(model.topNode.split.get.feature === 1) } - test("stump with categorical variables for multiclass classification, with just enough bins") { + test("Multiclass classification stump with categorical features, with just enough bins") { val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) @@ -757,7 +719,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(gain.rightImpurity === 0) } - test("stump with continuous variables for multiclass classification") { + test("Multiclass classification stump with continuous features") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, @@ -783,7 +745,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } - test("stump with continuous + categorical variables for multiclass classification") { + test("Multiclass classification stump with continuous + categorical features") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, @@ -808,7 +770,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.threshold < 2020) } - test("stump with 
categorical variables for ordered multiclass classification") { + test("Multiclass classification stump with 10-ary (ordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, From 5f94342bda903aac294ab76ab5ba89eb3751d3ab Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 18 Aug 2014 15:32:04 -0700 Subject: [PATCH 19/34] Added treeAggregate since not yet merged from master. Moved node indexing functions to Node. --- .../spark/mllib/tree/DecisionTree.scala | 46 +++++-------------- .../tree/impl/DecisionTreeMetadata.scala | 10 +++- .../apache/spark/mllib/tree/model/Node.scala | 45 +++++++++++++++++- 3 files changed, 64 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 3e316bea38fdf..2ed3ac9652852 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -18,11 +18,11 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ -import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging +import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ @@ -100,8 +100,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = DecisionTree.maxNodesInLevel(maxDepth + 1) - 1 - // TODO: CHECK val maxNumNodes = (2 << maxDepth) - 1 + val maxNumNodes = Node.maxNodesInLevel(maxDepth + 1) - 1 // Initialize an array to hold parent impurity calculations for each node. 
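    // (Editor's note: maxNodesInLevel(maxDepth + 1) - 1 = 2^(maxDepth + 1) - 1, the node
    // count of a complete binary tree with levels 0 through maxDepth.)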
val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) @@ -153,18 +152,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) timer.stop("findBestSplits") - val levelNodeIndexOffset = DecisionTree.maxNodesInLevel(level) - 1 + val levelNodeIndexOffset = Node.maxNodesInLevel(level) - 1 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { /*println(s"splitsStatsForLevel: index=$index") println(s"\t split: ${nodeSplitStats._1}") println(s"\t gain stats: ${nodeSplitStats._2}")*/ val nodeIndex = levelNodeIndexOffset + index - val isLeftChild = level != 0 && nodeIndex % 2 == 1 - val parentNodeIndex = if (isLeftChild) { // -1 for root node - (nodeIndex - 1) / 2 - } else { - (nodeIndex - 2) / 2 - } + val isLeftChild = Node.isLeftChild(nodeIndex) + val parentNodeIndex = Node.parentIndex(nodeIndex) // -1 for root node + // if (level == 0 || (nodesInTree(parentNodeIndex) && !nodes(parentNodeIndex).isLeaf)) // TODO: Use above check to skip unused branch of tree @@ -192,7 +188,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.stop("extractInfoForLowerLevels") logDebug("final best split = " + nodeSplitStats._1) } - require(DecisionTree.maxNodesInLevel(level) == splitsStatsForLevel.length) + require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -232,8 +228,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo if (level >= maxDepth) { return } - // TODO: Move nodeIndexOffset calc out of function? - val leftNodeIndex = (2 << level) - 1 + 2 * index + val leftNodeIndex = Node.maxNodesInSubtree(level) + 2 * index val leftImpurity = nodeSplitStats._2.leftImpurity logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity) parentImpurities(leftNodeIndex) = leftImpurity @@ -547,7 +542,7 @@ object DecisionTree extends Serializable with Logging { // numNodes: Number of nodes in this (level of tree, group), // where nodes at deeper (larger) levels may be divided into groups. - val numNodes = DecisionTree.maxNodesInLevel(level) / numGroups + val numNodes = Node.maxNodesInLevel(level) / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. @@ -619,16 +614,8 @@ object DecisionTree extends Serializable with Logging { } } - def nodeIndexToLevel(idx: Int): Int = { - if (idx == 0) { - 0 - } else { - math.floor(math.log(idx) / math.log(2)).toInt - } - } - // Used for treePointToNodeIndex - val levelOffset = DecisionTree.maxNodesInLevel(level) - 1 + val levelOffset = Node.maxNodesInLevel(level) - 1 /** * Find the node index for the given example. 
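[Editor's aside: a self-contained sketch, not part of the patch, of the 0-based heap
indexing that these hunks consolidate into Node. The helper names mirror the patch; the
wrapper object and the checks are the editor's illustration of the invariants: root = 0,
the children of node i are 2i + 1 and 2i + 2, and level L starts at offset 2^L - 1.]

object ZeroBasedIndexingSketch {
  def isLeftChild(nodeIndex: Int): Boolean = nodeIndex != 0 && nodeIndex % 2 == 1

  def parentIndex(nodeIndex: Int): Int =
    if (isLeftChild(nodeIndex)) (nodeIndex - 1) / 2 else (nodeIndex - 2) / 2

  def maxNodesInLevel(level: Int): Int = 1 << level  // 2^level nodes in a full level

  def levelNodeIndexOffset(level: Int): Int = maxNodesInLevel(level) - 1

  def main(args: Array[String]): Unit = {
    // Level 2 holds nodes 3..6; 3 is the left child of 1 and 4 is its right child.
    assert(levelNodeIndexOffset(2) == 3)
    assert(isLeftChild(3) && parentIndex(3) == 1)
    assert(!isLeftChild(4) && parentIndex(4) == 1)
  }
}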
@@ -787,7 +774,7 @@ object DecisionTree extends Serializable with Logging { timer.start("aggregation") val binAggregates = { val initAgg = getEmptyBinAggregates(metadata, numNodes) - input.aggregate(initAgg)(binSeqOp, binCombOp) + input.treeAggregate(initAgg)(binSeqOp, binCombOp) } timer.stop("aggregation") /* @@ -804,7 +791,7 @@ object DecisionTree extends Serializable with Logging { // Calculate best splits for all nodes at a given level timer.start("chooseSplits") val bestSplits = new Array[(Split, InformationGainStats)](numNodes) - val nodeIndexOffset = DecisionTree.maxNodesInLevel(level) - 1 + val nodeIndexOffset = Node.maxNodesInLevel(level) - 1 // Iterating over all nodes at this level var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -1160,7 +1147,6 @@ object DecisionTree extends Serializable with Logging { * For multiclass classification with a low-arity feature * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), * the feature is split based on subsets of categories. - * There are (1 << maxFeatureValue - 1) - 1 splits. * (b) "ordered features" * For regression and binary classification, * and for multiclass classification with a high-arity feature, @@ -1366,12 +1352,4 @@ object DecisionTree extends Serializable with Logging { categories } - private[tree] def maxNodesInLevel(level: Int): Int = { - math.pow(2, level).toInt - } - - private[tree] def numUnorderedBins(arity: Int): Int = { - (math.pow(2, arity - 1) - 1).toInt - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index fc5e8a8b6123e..a5e19eeef1a85 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -102,7 +102,7 @@ private[tree] object DecisionTreeMetadata { // Note: The above check is equivalent to checking: // numUnorderedBins = (1 << k - 1) - 1 < maxBins unorderedFeatures.add(f) - numBins(f) = DecisionTree.numUnorderedBins(k) + numBins(f) = numUnorderedBins(k) } else { // TODO: Check the below k <= maxBins. // This used to be k < maxPossibleBins, but <= should work. @@ -129,4 +129,12 @@ private[tree] object DecisionTreeMetadata { strategy.impurity, strategy.quantileCalculationStrategy) } + /** + * Given the arity of a categorical feature (arity = number of categories), + * return the number of bins for the feature if it is to be treated as an unordered feature. + */ + def numUnorderedBins(arity: Int): Int = { + (1 << arity - 1) - 1 + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 0eee6262781c1..5ed8722b534af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -52,8 +52,7 @@ class Node ( */ def build(nodes: Array[Node]): Unit = { - logDebug("building node " + id + " at level " + - (scala.math.log(id + 1)/scala.math.log(2)).toInt ) + logDebug("building node " + id + " at level " + Node.indexToLevel(id)) logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) @@ -148,3 +147,45 @@ class Node ( } } + +private[tree] object Node { + + /** + * Return the level of a tree which the given node is in. 
+ */ + def indexToLevel(nodeIndex: Int): Int = { + math.floor(math.log(nodeIndex + 1) / math.log(2)).toInt + } + + /** + * Returns true if this is a left child. + * Note: Returns false for the root. + */ + def isLeftChild(nodeIndex: Int): Boolean = nodeIndex != 0 && nodeIndex % 2 == 1 + + /** + * Get the parent index of the given node, or -1 if it is the root. + */ + def parentIndex(nodeIndex: Int): Int = { + if (isLeftChild(nodeIndex)) { // -1 for root node + (nodeIndex - 1) / 2 + } else { + (nodeIndex - 2) / 2 + } + + } + + /** + * Return the maximum number of nodes which can be in the given level of the tree. + * @param level Level of tree (0 = root). + */ + private[tree] def maxNodesInLevel(level: Int): Int = 1 << level + + /** + * Return the maximum number of nodes which can be in or above the given level of the tree + * (i.e., for the entire subtree from the root to this level). + * @param level Level of tree (0 = root). + */ + private[tree] def maxNodesInSubtree(level: Int): Int = (2 << level) - 1 + +} From a40f8f1b4c110de70dfc713c109b998a58c44b14 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 18 Aug 2014 22:44:39 -0700 Subject: [PATCH 20/34] Changed nodes to be indexed from 1. Tests work. --- .../spark/mllib/tree/DecisionTree.scala | 191 ++++++------------ .../tree/impl/DecisionTreeMetadata.scala | 10 +- .../apache/spark/mllib/tree/model/Node.scala | 71 +++---- .../spark/mllib/tree/DecisionTreeSuite.scala | 58 +++--- 4 files changed, 132 insertions(+), 198 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 2ed3ac9652852..2d80821c79d2b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -76,21 +76,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo timer.stop("findSplitsBins") logDebug("numBins = " + numBins) - /* - println(s"splits:") - for (f <- Range(0, splits.size)) { - for (s <- Range(0, splits(f).size)) { - println(s" splits($f)($s): ${splits(f)(s)}") - } - } - println(s"bins:") - for (f <- Range(0, bins.size)) { - for (s <- Range(0, bins(f).size)) { - println(s" bins($f)($s): ${bins(f)(s)}") - } - } - */ - // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) @@ -99,22 +84,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val numFeatures = metadata.numFeatures // depth of the decision tree val maxDepth = strategy.maxDepth - // the max number of nodes possible given the depth of the tree - val maxNumNodes = Node.maxNodesInLevel(maxDepth + 1) - 1 + // the max number of nodes possible given the depth of the tree, plus 1 + val maxNumNodes_p1 = Node.maxNodesInLevel(maxDepth + 1) // Initialize an array to hold parent impurity calculations for each node. - val parentImpurities = new Array[Double](maxNumNodes) + val parentImpurities = new Array[Double](maxNumNodes_p1) // dummy value for top node (updated during first split calculation) - val nodes = new Array[Node](maxNumNodes) - // TODO: DO THIS OPTIMIZATION: - // val nodesInTree = Array.fill[Boolean](maxNumNodes)(false) // put into nodes array later? 
- // nodesInTree(0) = true + val nodes = new Array[Node](maxNumNodes_p1) // Calculate level for single group construction // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - // TODO: Calculate numElementsPerNode in metadata (more precisely) + // TODO: Calculate memory usage more precisely. val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins) logDebug("numElementsPerNode = " + numElementsPerNode) @@ -152,17 +134,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) timer.stop("findBestSplits") - val levelNodeIndexOffset = Node.maxNodesInLevel(level) - 1 + val levelNodeIndexOffset = Node.maxNodesInSubtree(level - 1) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - /*println(s"splitsStatsForLevel: index=$index") - println(s"\t split: ${nodeSplitStats._1}") - println(s"\t gain stats: ${nodeSplitStats._2}")*/ - val nodeIndex = levelNodeIndexOffset + index - val isLeftChild = Node.isLeftChild(nodeIndex) - val parentNodeIndex = Node.parentIndex(nodeIndex) // -1 for root node - - // if (level == 0 || (nodesInTree(parentNodeIndex) && !nodes(parentNodeIndex).isLeaf)) - // TODO: Use above check to skip unused branch of tree + val nodeIndex = levelNodeIndexOffset + index + 1 // + 1 since nodes indexed from 1 // Extract info for this node (index) at the current level. timer.start("extractNodeInfo") @@ -176,7 +150,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo if (level != 0) { // Set parent. - if (isLeftChild) { + val parentNodeIndex = Node.parentIndex(nodeIndex) + if (Node.isLeftChild(nodeIndex)) { nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex)) } else { nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex)) @@ -184,9 +159,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } // Extract info for nodes at the next lower level. timer.start("extractInfoForLowerLevels") - extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities) + if (level < maxDepth) { + val leftChildIndex = Node.leftChildIndex(nodeIndex) + val leftImpurity = stats.leftImpurity + logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity) + parentImpurities(leftChildIndex) = leftImpurity + + val rightChildIndex = Node.rightChildIndex(nodeIndex) + val rightImpurity = stats.rightImpurity + logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity) + parentImpurities(rightChildIndex) = rightImpurity + } timer.stop("extractInfoForLowerLevels") - logDebug("final best split = " + nodeSplitStats._1) + logDebug("final best split = " + split) } require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. @@ -204,7 +189,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Initialize the top or root node of the tree. - val topNode = nodes(0) + val topNode = nodes(1) // Build the full tree using the node info calculated in the level-wise best split calculations. 
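    // (Editor's note: build() recurses from nodes(1) via Node.leftChildIndex(id) = 2 * id
    // and Node.rightChildIndex(id) = 2 * id + 1, so with 1-based indexing, slot 0 of the
    // nodes array is simply unused.)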
topNode.build(nodes) @@ -212,32 +197,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + println(s"$timer") new DecisionTreeModel(topNode, strategy.algo) } - /** - * Extract the decision tree node information for the children of the node - */ - private def extractInfoForLowerLevels( - level: Int, - index: Int, - maxDepth: Int, - nodeSplitStats: (Split, InformationGainStats), - parentImpurities: Array[Double]): Unit = { - if (level >= maxDepth) { - return - } - val leftNodeIndex = Node.maxNodesInSubtree(level) + 2 * index - val leftImpurity = nodeSplitStats._2.leftImpurity - logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity) - parentImpurities(leftNodeIndex) = leftImpurity - - val rightNodeIndex = leftNodeIndex + 1 - val rightImpurity = nodeSplitStats._2.rightImpurity - logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity) - parentImpurities(rightNodeIndex) = rightImpurity - } } @@ -572,6 +536,9 @@ object DecisionTree extends Serializable with Logging { * * @return Leaf index if the data point reaches a leaf. * Otherwise, last node reachable in tree matching this example. + * Note: This is the global node index, i.e., the index used in the tree. + * This index is different from the index used during training a particular + * set of nodes in a (level, group). */ def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = { if (node.isLeaf) { @@ -600,9 +567,9 @@ object DecisionTree extends Serializable with Logging { if (node.leftNode.isEmpty || node.rightNode.isEmpty) { // Return index from next layer of nodes to train if (splitLeft) { - node.id * 2 + 1 // left + Node.leftChildIndex(node.id) } else { - node.id * 2 + 2 // right + Node.rightChildIndex(node.id) } } else { if (splitLeft) { @@ -615,7 +582,7 @@ object DecisionTree extends Serializable with Logging { } // Used for treePointToNodeIndex - val levelOffset = Node.maxNodesInLevel(level) - 1 + val levelOffset = Node.maxNodesInSubtree(level - 1) /** * Find the node index for the given example. @@ -626,9 +593,12 @@ object DecisionTree extends Serializable with Logging { if (level == 0) { 0 } else { - val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures) + val globalNodeIndex = predictNodeIndex(nodes(1), treePoint.binnedFeatures) // Get index for this (level, group). - globalNodeIndex - levelOffset - groupShift + // - levelOffset corrects for nodes before this level. + // - groupShift corrects for groups in this level before the current group. + // - 1 corrects for the fact that globalNodeIndex starts at 1, not 0. + globalNodeIndex - levelOffset - groupShift - 1 } } @@ -729,8 +699,8 @@ object DecisionTree extends Serializable with Logging { * @return agg */ def binSeqOp( - agg: Array[Array[Array[ImpurityAggregator]]], - treePoint: TreePoint): Array[Array[Array[ImpurityAggregator]]] = { + agg: Array[Array[Array[ImpurityAggregator]]], + treePoint: TreePoint): Array[Array[Array[ImpurityAggregator]]] = { val nodeIndex = treePointToNodeIndex(treePoint) // If the example does not reach this level, then nodeIndex < 0. 
// If the example reaches this level but is handled in a different group, @@ -777,37 +747,21 @@ object DecisionTree extends Serializable with Logging { input.treeAggregate(initAgg)(binSeqOp, binCombOp) } timer.stop("aggregation") - /* - println("binAggregates:") - for (n <- Range(0, binAggregates.size)) { - for (f <- Range(0, binAggregates(n).size)) { - for (b <- Range(0, binAggregates(n)(f).size)) { - println(s" ($n, $f, $b): ${binAggregates(n)(f)(b)}") - } - } - } - */ // Calculate best splits for all nodes at a given level timer.start("chooseSplits") val bestSplits = new Array[(Split, InformationGainStats)](numNodes) - val nodeIndexOffset = Node.maxNodesInLevel(level) - 1 + val globalNodeIndexOffset = Node.maxNodesInSubtree(level - 1) + groupShift + 1 // Iterating over all nodes at this level var nodeIndex = 0 while (nodeIndex < numNodes) { - //println(s" HANDLING node $nodeIndex") - val nodeImpurityIndex = nodeIndexOffset + nodeIndex + groupShift - //val binsForNode: Array[Double] = getBinDataForNode(node) - //logDebug("nodeImpurityIndex = " + nodeImpurityIndex) - val parentNodeImpurity = parentImpurities(nodeImpurityIndex) - logDebug("parent node impurity = " + parentNodeImpurity) + val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex) + logDebug("node impurity = " + nodeImpurity) val (bestFeatureIndex, bestSplitIndex, bestGain) = - binsToBestSplit(binAggregates(nodeIndex), parentNodeImpurity, level, metadata) + binsToBestSplit(binAggregates(nodeIndex), nodeImpurity, level, metadata) bestSplits(nodeIndex) = (splits(bestFeatureIndex)(bestSplitIndex), bestGain) logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) - //println(s"bestSplits(node:$node): ${bestSplits(node)}") - nodeIndex += 1 } timer.stop("chooseSplits") @@ -823,11 +777,11 @@ object DecisionTree extends Serializable with Logging { * @return information gain and statistics for all splits */ def calculateGainForSplit( - leftNodeAgg: ImpurityAggregator, - rightNodeAgg: ImpurityAggregator, - topImpurity: Double, - level: Int, - metadata: DecisionTreeMetadata): InformationGainStats = { + leftNodeAgg: ImpurityAggregator, + rightNodeAgg: ImpurityAggregator, + topImpurity: Double, + level: Int, + metadata: DecisionTreeMetadata): InformationGainStats = { val leftCount = leftNodeAgg.count val rightCount = rightNodeAgg.count @@ -835,7 +789,6 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftCount + rightCount if (totalCount == 0) { // Return arbitrary prediction. - //println(s"BLAH: feature $featureIndex, split $splitIndex. totalCount == 0") return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) } @@ -851,15 +804,9 @@ object DecisionTree extends Serializable with Logging { val predict = parentNodeAgg.predict val prob = parentNodeAgg.prob(predict) - val leftImpurity = leftNodeAgg.calculate() // Note: 0 if count = 0 + val leftImpurity = leftNodeAgg.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightNodeAgg.calculate() - /* - println(s"calculateGainForSplit") - println(s"\t leftImpurity = $leftImpurity, leftNodeAgg: $leftNodeAgg") - println(s"\t rightImpurity = $rightImpurity, rightNodeAgg: $rightNodeAgg") - */ - val leftWeight = leftCount / totalCount.toDouble val rightWeight = rightCount / totalCount.toDouble @@ -914,8 +861,9 @@ object DecisionTree extends Serializable with Logging { * TODO: Extract in-place. 
*/ def extractLeftRightNodeAggregates( - nodeAggregates: Array[Array[ImpurityAggregator]], - metadata: DecisionTreeMetadata): (Array[Array[ImpurityAggregator]], Array[Array[ImpurityAggregator]]) = { + nodeAggregates: Array[Array[ImpurityAggregator]], + metadata: DecisionTreeMetadata): + (Array[Array[ImpurityAggregator]], Array[Array[ImpurityAggregator]]) = { val numClasses = metadata.numClasses val numFeatures = metadata.numFeatures @@ -926,10 +874,10 @@ object DecisionTree extends Serializable with Logging { * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value */ def findAggForUnorderedFeature( - binData: Array[Array[ImpurityAggregator]], - leftNodeAgg: Array[Array[ImpurityAggregator]], - rightNodeAgg: Array[Array[ImpurityAggregator]], - featureIndex: Int) { + binData: Array[Array[ImpurityAggregator]], + leftNodeAgg: Array[Array[ImpurityAggregator]], + rightNodeAgg: Array[Array[ImpurityAggregator]], + featureIndex: Int) { // TODO: Don't pass in featureIndex; use index before call. // Note: numBins = numSplits for unordered features. val numBins = metadata.numBins(featureIndex) @@ -953,10 +901,10 @@ object DecisionTree extends Serializable with Logging { * TODO: We could avoid doing one of these cumulative sums. */ def findAggForOrderedFeature( - binData: Array[Array[ImpurityAggregator]], - leftNodeAgg: Array[Array[ImpurityAggregator]], - rightNodeAgg: Array[Array[ImpurityAggregator]], - featureIndex: Int) { + binData: Array[Array[ImpurityAggregator]], + leftNodeAgg: Array[Array[ImpurityAggregator]], + rightNodeAgg: Array[Array[ImpurityAggregator]], + featureIndex: Int) { // TODO: Don't pass in featureIndex; use index before call. val numSplits = metadata.numSplits(featureIndex) @@ -983,13 +931,7 @@ object DecisionTree extends Serializable with Logging { splitIndex += 1 } } else { // ordered categorical feature - /* TODO: This is a temp fix. - * Eventually, for ordered categorical features, change splits and bins to be - * for individual categories instead of running totals over a pre-defined category - * ordering. Then, we could choose the ordering in this function, tailoring it - * to this particular node. - */ - var splitIndex = 0 + var splitIndex = 0 while (splitIndex < numSplits) { // no need to clone since no cumulative sum is needed leftNodeAgg(featureIndex)(splitIndex) = binData(featureIndex)(splitIndex) @@ -1028,20 +970,12 @@ object DecisionTree extends Serializable with Logging { * @return tuple (best feature index, best split index, information gain) */ def binsToBestSplit( - nodeAggregates: Array[Array[ImpurityAggregator]], - nodeImpurity: Double, - level: Int, - metadata: DecisionTreeMetadata): (Int, Int, InformationGainStats) = { + nodeAggregates: Array[Array[ImpurityAggregator]], + nodeImpurity: Double, + level: Int, + metadata: DecisionTreeMetadata): (Int, Int, InformationGainStats) = { logDebug("node impurity = " + nodeImpurity) - /* - println("nodeAggregates") - for (f <- Range(0, nodeAggregates.size)) { - for (b <- Range(0, nodeAggregates(f).size)) { - println(s"nodeAggregates($f)($b): ${nodeAggregates(f)(b)}") - } - } - */ // Extract left right node aggregates. 
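    // (Editor's note: for continuous features these are prefix/suffix cumulative sums
    // over bins; for unordered categorical features the left/right halves of the doubled
    // bin axis are read off directly, and for ordered categorical features each bin is
    // used as-is, with no cumulative sum.)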
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(nodeAggregates, metadata) @@ -1067,7 +1001,6 @@ object DecisionTree extends Serializable with Logging { bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex - //println(s" feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats") } splitIndex += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index a5e19eeef1a85..4fca7a4e4eb98 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -105,11 +105,11 @@ private[tree] object DecisionTreeMetadata { numBins(f) = numUnorderedBins(k) } else { // TODO: Check the below k <= maxBins. - // This used to be k < maxPossibleBins, but <= should work. + // Checking k <= maxPossibleBins should work. // However, there may have been a 1-off error later on allocating 1 extra // (unused) bin. // TODO: Allow this case, where we simply will know nothing about some categories? - require(k < maxPossibleBins, + require(k <= maxPossibleBins, s"maxBins (= $maxPossibleBins) should be greater than max categories " + s"in categorical features (>= $k)") numBins(f) = k @@ -117,9 +117,9 @@ private[tree] object DecisionTreeMetadata { } } else { strategy.categoricalFeaturesInfo.foreach { case (f, k) => - require(k < maxPossibleBins, - s"maxBins (= $maxPossibleBins) should be greater than max categories " + - s"in categorical features (>= $k)") + require(k <= maxPossibleBins, + s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " + + s"in categorical features (= ${strategy.categoricalFeaturesInfo.values.max})") numBins(f) = k } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 5ed8722b534af..43023f31e0286 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -24,8 +24,13 @@ import org.apache.spark.mllib.linalg.Vector /** * :: DeveloperApi :: - * Node in a decision tree - * @param id integer node id + * Node in a decision tree. + * + * About node indexing: + * Nodes are indexed from 1. Node 1 is the root; nodes 2,3 are the left,right children. + * Node index 0 is not used. + * + * @param id integer node id, from 1 * @param predict predicted value at the node * @param isLeaf whether the leaf is a node * @param split split to calculate left and right nodes @@ -51,16 +56,13 @@ class Node ( * @param nodes array of nodes */ def build(nodes: Array[Node]): Unit = { - logDebug("building node " + id + " at level " + Node.indexToLevel(id)) logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) if (!isLeaf) { - val leftNodeIndex = id * 2 + 1 - val rightNodeIndex = id * 2 + 2 - leftNode = Some(nodes(leftNodeIndex)) - rightNode = Some(nodes(rightNodeIndex)) + leftNode = Some(nodes(Node.leftChildIndex(id))) + rightNode = Some(nodes(Node.rightChildIndex(id))) leftNode.get.build(nodes) rightNode.get.build(nodes) } @@ -95,24 +97,20 @@ class Node ( * Get the number of nodes in tree below this node, including leaf nodes. * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. 
*/ - private[tree] def numDescendants: Int = { - if (isLeaf) { - 0 - } else { - 2 + leftNode.get.numDescendants + rightNode.get.numDescendants - } + private[tree] def numDescendants: Int = if (isLeaf) { + 0 + } else { + 2 + leftNode.get.numDescendants + rightNode.get.numDescendants } /** * Get depth of tree from this node. * E.g.: Depth 0 means this is a leaf node. */ - private[tree] def subtreeDepth: Int = { - if (isLeaf) { - 0 - } else { - 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth) - } + private[tree] def subtreeDepth: Int = if (isLeaf) { + 0 + } else { + 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth) } /** @@ -151,30 +149,35 @@ class Node ( private[tree] object Node { /** - * Return the level of a tree which the given node is in. + * Return the index of the left child of this node. */ - def indexToLevel(nodeIndex: Int): Int = { - math.floor(math.log(nodeIndex + 1) / math.log(2)).toInt - } + def leftChildIndex(nodeIndex: Int): Int = nodeIndex * 2 /** - * Returns true if this is a left child. - * Note: Returns false for the root. + * Return the index of the right child of this node. */ - def isLeftChild(nodeIndex: Int): Boolean = nodeIndex != 0 && nodeIndex % 2 == 1 + def rightChildIndex(nodeIndex: Int): Int = nodeIndex * 2 + 1 /** - * Get the parent index of the given node, or -1 if it is the root. + * Get the parent index of the given node, or 0 if it is the root. */ - def parentIndex(nodeIndex: Int): Int = { - if (isLeftChild(nodeIndex)) { // -1 for root node - (nodeIndex - 1) / 2 - } else { - (nodeIndex - 2) / 2 - } + def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1 + /** + * Return the level of a tree which the given node is in. + */ + def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) { + throw new IllegalArgumentException(s"0 is not a valid node index.") + } else { + java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex)) } + /** + * Returns true if this is a left child. + * Note: Returns false for the root. + */ + def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0 + /** * Return the maximum number of nodes which can be in the given level of the tree. * @param level Level of tree (0 = root). @@ -186,6 +189,6 @@ private[tree] object Node { * (i.e., for the entire subtree from the root to this level). * @param level Level of tree (0 = root). 
*/ - private[tree] def maxNodesInSubtree(level: Int): Int = (2 << level) - 1 + private[tree] def maxNodesInSubtree(level: Int): Int = (1 << level + 1) - 1 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index e911dff8db830..75dea3556a403 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -190,31 +190,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0)(1).highSplit.categories === splits(0)(1).categories) assert(bins(0)(2).lowSplit.categories === splits(0)(1).categories) - /* - assert(bins(0)(2).lowSplit.categories.length === 1) - assert(bins(0)(2).lowSplit.categories.contains(1.0)) - */ - //assert(bins(0)(2).highSplit.categories === splits(0)(2).categories) assert(bins(0)(2).highSplit.categories === List(2.0, 0.0, 1.0)) assert(bins(1)(0).lowSplit.categories.length === 0) assert(bins(1)(0).highSplit.categories === splits(1)(0).categories) -/* assert(bins(1)(0).highSplit.categories.length === 1) - assert(bins(1)(0).highSplit.categories.contains(0.0))*/ assert(bins(1)(1).lowSplit.categories === splits(1)(0).categories) assert(bins(1)(1).highSplit.categories === splits(1)(1).categories) -/* assert(bins(1)(1).lowSplit.categories.length === 1) - assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length === 1) - assert(bins(1)(1).highSplit.categories.contains(1.0))*/ assert(bins(1)(2).lowSplit.categories === splits(1)(1).categories) - //assert(bins(1)(2).highSplit.categories === splits(1)(2).categories) assert(bins(1)(2).highSplit.categories === List(2.0, 1.0, 0.0)) -/* assert(bins(1)(2).lowSplit.categories.length === 1) - assert(bins(1)(2).lowSplit.categories.contains(1.0)) */ } test("extract categories from a number for multiclass classification") { @@ -397,7 +383,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 3) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 @@ -429,7 +415,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 @@ -483,7 +469,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -511,7 +497,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, 
Array(0.0), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -540,7 +526,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -569,7 +555,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -596,12 +582,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // Train a 1-node model val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100) val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val nodes: Array[Node] = new Array[Node](7) - nodes(0) = modelOneNode.topNode - nodes(0).leftNode = None - nodes(0).rightNode = None + val nodes: Array[Node] = new Array[Node](8) + nodes(1) = modelOneNode.topNode + nodes(1).leftNode = None + nodes(1).rightNode = None - val parentImpurities = Array(0.5, 0.5, 0.5) + val parentImpurities = Array(0, 0.5, 0.5, 0.5) // Single group second level tree construction. 
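    // (Editor's note: maxLevelForSingleGroup is 10 in the call below, so level 1 <= 10
    // trains both of its nodes in a single group; the companion call with 0 a few lines
    // down forces two groups.)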
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) @@ -643,7 +629,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -705,7 +691,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -732,7 +718,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -758,7 +744,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -780,7 +766,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -791,6 +777,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } + test("Multiclass classification tree with 10-ary (ordered) categorical features," + + " with just enough bins") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + val rdd = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 3, maxBins = 10, + categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(rdd, strategy) + validateClassifier(model, arr, 0.6) + } } From d7c53ee08d4a60d3a1b6ff5a6e4589c6c9698030 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 19 Aug 2014 11:32:02 -0700 Subject: [PATCH 21/34] Added more doc for ImpurityAggregator --- .../spark/mllib/tree/DecisionTree.scala | 2 +- .../spark/mllib/tree/impurity/Impurity.scala | 37 ++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 2d80821c79d2b..029ae6afcd130 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -954,7 +954,7 @@ object DecisionTree extends Serializable with Logging { featureIndex += 1 } } else { // Regression - var featureIndex = 0 + var featureIndex = 0 while (featureIndex < numFeatures) { findAggForOrderedFeature(nodeAggregates, leftNodeAgg, rightNodeAgg, featureIndex) featureIndex += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 807207d827137..e6418960d894b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -48,17 +48,35 @@ trait Impurity extends Serializable { def calculate(count: Double, sum: Double, sumSquares: Double): Double } - +/** + * This class holds a set of sufficient statistics for computing impurity from a sample. + * @param statsSize Length of the vector of sufficient statistics. + */ private[tree] abstract class ImpurityAggregator(statsSize: Int) extends Serializable { + /** + * Sufficient statistics for calculating impurity. + */ var counts: Array[Double] = new Array[Double](statsSize) def copy: ImpurityAggregator + /** + * Add the given label to this aggregator. + */ def add(label: Double): Unit + /** + * Compute the impurity for the samples given so far. + * If no samples have been collected, return 0. + */ def calculate(): Double + /** + * Merge another aggregator into this one, modifying this aggregator. + * @param other Aggregator of the same type. + * @return merged aggregator + */ def merge(other: ImpurityAggregator): ImpurityAggregator = { require(counts.size == other.counts.size, s"Two ImpurityAggregator instances cannot be merged with different counts sizes." + @@ -71,14 +89,31 @@ private[tree] abstract class ImpurityAggregator(statsSize: Int) extends Serializ this } + /** + * Number of samples added to this aggregator. + */ def count: Long + /** + * Create a new (empty) aggregator of the same type as this one. + */ def newAggregator: ImpurityAggregator + /** + * Return the prediction corresponding to the set of labels given to this aggregator. + */ def predict: Double + /** + * Return the probability of the prediction returned by [[predict]], + * or -1 if no probability is available. + */ def prob(label: Double): Double = -1 + /** + * Return the index of the largest element in this array. + * If there are ties, the first maximal element is chosen. + */ protected def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { case ((maxIndex, maxValue, currentIndex), currentValue) => From fd8df3063de03c5713ddf741ce47c7e91284798a Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Tue, 19 Aug 2014 15:44:35 -0700 Subject: [PATCH 22/34] Moved some aggregation helpers outside of findBestSplitsPerGroup --- .../spark/mllib/tree/DecisionTree.scala | 262 +++++++++--------- 1 file changed, 127 insertions(+), 135 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 029ae6afcd130..6f45d1c32d0c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -453,6 +453,129 @@ object DecisionTree extends Serializable with Logging { } } + /** + * Get the node index corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a node + * at the current level being trained; that node's index is returned. + * + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. + * Note: This is the global node index, i.e., the index used in the tree. + * This index is different from the index used during training a particular + * set of nodes in a (level, group). + */ + def predictNodeIndex(node: Node, binnedFeatures: Array[Int], bins: Array[Array[Bin]], unorderedFeatures: Set[Int]): Int = { + if (node.isLeaf) { + node.id + } else { + val featureIndex = node.split.get.feature + val splitLeft = node.split.get.featureType match { + case Continuous => { + val binIndex = binnedFeatures(featureIndex) + val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold + // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] + // We do not need to check lowSplit since bins are separated by splits. + featureValueUpperBound <= node.split.get.threshold + } + case Categorical => { + val featureValue = if (unorderedFeatures.contains(featureIndex)) { + binnedFeatures(featureIndex) + } else { + val binIndex = binnedFeatures(featureIndex) + bins(featureIndex)(binIndex).category + } + node.split.get.categories.contains(featureValue) + } + case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") + } + if (node.leftNode.isEmpty || node.rightNode.isEmpty) { + // Return index from next layer of nodes to train + if (splitLeft) { + Node.leftChildIndex(node.id) + } else { + Node.rightChildIndex(node.id) + } + } else { + if (splitLeft) { + predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures) + } else { + predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures) + } + } + } + } + + /** + * Helper for binSeqOp. + * + * @param agg Array storing aggregate calculation. + * For ordered features, this is of size: + * numClasses * numBins * numFeatures * numNodes. + * For unordered features, this is of size: + * 2 * numClasses * numBins * numFeatures * numNodes. + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + */ + def someUnorderedBinSeqOp( + agg: Array[Array[Array[ImpurityAggregator]]], + treePoint: TreePoint, + nodeIndex: Int, bins: Array[Array[Bin]], unorderedFeatures: Set[Int]): Unit = { + // Iterate over all features. 
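    // (Editor's note: each unordered feature writes to 2 * numCategoricalBins slots per
    // node in the loop below: bins [0, numCategoricalBins) count examples matching each
    // candidate split, and bins [numCategoricalBins, 2 * numCategoricalBins) count the
    // rest.)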
+ val numFeatures = treePoint.binnedFeatures.size + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (unorderedFeatures.contains(featureIndex)) { + // Unordered feature + val featureValue = treePoint.binnedFeatures(featureIndex) + // Update the left or right count for one bin. + // Find all matching bins and increment their values. + val numCategoricalBins = bins(featureIndex).size //metadata.numBins(featureIndex) + var binIndex = 0 + while (binIndex < numCategoricalBins) { + if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { + agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label) + } else { + agg(nodeIndex)(featureIndex)(numCategoricalBins + binIndex).add(treePoint.label) + } + binIndex += 1 + } + } else { + // Ordered feature + val binIndex = treePoint.binnedFeatures(featureIndex) + agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label) + } + featureIndex += 1 + } + } + + /** + * Helper for binSeqOp: for regression and for classification with only ordered features. + * + * Performs a sequential aggregation over a partition for regression. + * For l nodes, k features, + * the count, sum, sum of squares of one of the p bins is incremented. + * + * @param agg Array storing aggregate calculation, updated by this function. + * Size: 3 * numBins * numFeatures * numNodes + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @return agg + */ + def orderedBinSeqOp( + agg: Array[Array[Array[ImpurityAggregator]]], + treePoint: TreePoint, + nodeIndex: Int): Unit = { + val label = treePoint.label + // Iterate over all features. + val numFeatures = treePoint.binnedFeatures.size + var featureIndex = 0 + while (featureIndex < numFeatures) { + val binIndex = treePoint.binnedFeatures(featureIndex) + agg(nodeIndex)(featureIndex)(binIndex).add(label) + featureIndex += 1 + } + } + /** * Returns an array of optimal splits for a group of nodes at a given level * @@ -529,60 +652,8 @@ object DecisionTree extends Serializable with Logging { // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex - /** - * Get the node index corresponding to this data point. - * This function mimics prediction, passing an example from the root node down to a node - * at the current level being trained; that node's index is returned. - * - * @return Leaf index if the data point reaches a leaf. - * Otherwise, last node reachable in tree matching this example. - * Note: This is the global node index, i.e., the index used in the tree. - * This index is different from the index used during training a particular - * set of nodes in a (level, group). - */ - def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = { - if (node.isLeaf) { - node.id - } else { - val featureIndex = node.split.get.feature - val splitLeft = node.split.get.featureType match { - case Continuous => { - val binIndex = binnedFeatures(featureIndex) - val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold - // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] - // We do not need to check lowSplit since bins are separated by splits. 
- featureValueUpperBound <= node.split.get.threshold - } - case Categorical => { - val featureValue = if (metadata.isUnordered(featureIndex)) { - binnedFeatures(featureIndex) - } else { - val binIndex = binnedFeatures(featureIndex) - bins(featureIndex)(binIndex).category - } - node.split.get.categories.contains(featureValue) - } - case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") - } - if (node.leftNode.isEmpty || node.rightNode.isEmpty) { - // Return index from next layer of nodes to train - if (splitLeft) { - Node.leftChildIndex(node.id) - } else { - Node.rightChildIndex(node.id) - } - } else { - if (splitLeft) { - predictNodeIndex(node.leftNode.get, binnedFeatures) - } else { - predictNodeIndex(node.rightNode.get, binnedFeatures) - } - } - } - } - // Used for treePointToNodeIndex - val levelOffset = Node.maxNodesInSubtree(level - 1) + val globalNodeIndexOffset = Node.maxNodesInSubtree(level - 1) + groupShift + 1 /** * Find the node index for the given example. @@ -593,90 +664,12 @@ object DecisionTree extends Serializable with Logging { if (level == 0) { 0 } else { - val globalNodeIndex = predictNodeIndex(nodes(1), treePoint.binnedFeatures) + val globalNodeIndex = predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures) // Get index for this (level, group). // - levelOffset corrects for nodes before this level. // - groupShift corrects for groups in this level before the current group. // - 1 corrects for the fact that globalNodeIndex starts at 1, not 0. - globalNodeIndex - levelOffset - groupShift - 1 - } - } - - - val rightChildShift = numClasses * numBins * numFeatures * numNodes - - /** - * Helper for binSeqOp. - * - * @param agg Array storing aggregate calculation. - * For ordered features, this is of size: - * numClasses * numBins * numFeatures * numNodes. - * For unordered features, this is of size: - * 2 * numClasses * numBins * numFeatures * numNodes. - * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - */ - def someUnorderedBinSeqOp( - agg: Array[Array[Array[ImpurityAggregator]]], - treePoint: TreePoint, - nodeIndex: Int): Unit = { - val label = treePoint.label - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (metadata.isUnordered(featureIndex)) { - // Unordered feature - val featureValue = treePoint.binnedFeatures(featureIndex) - // Update the left or right count for one bin. - // Find all matching bins and increment their values. - val numCategoricalBins = metadata.numBins(featureIndex) - var binIndex = 0 - while (binIndex < numCategoricalBins) { - if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { - agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label) - } else { - agg(nodeIndex)(featureIndex)(numCategoricalBins + binIndex).add(treePoint.label) - } - binIndex += 1 - } - } else { - // Ordered feature - val binIndex = treePoint.binnedFeatures(featureIndex) - agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label) - } - featureIndex += 1 - } - } - - /** - * Helper for binSeqOp: for regression and for classification with only ordered features. - * - * Performs a sequential aggregation over a partition for regression. - * For l nodes, k features, - * the count, sum, sum of squares of one of the p bins is incremented. - * - * @param agg Array storing aggregate calculation, updated by this function. 
- * Size: 3 * numBins * numFeatures * numNodes - * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - * @return agg - */ - def orderedBinSeqOp( - agg: Array[Array[Array[ImpurityAggregator]]], - treePoint: TreePoint, - nodeIndex: Int): Unit = { - val label = treePoint.label - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Update count, sum, and sum^2 for one bin. - val binIndex = treePoint.binnedFeatures(featureIndex) - if (binIndex >= agg(nodeIndex)(featureIndex).size) { - throw new RuntimeException( - s"binIndex: $binIndex, agg(nodeIndex)(featureIndex).size = ${agg(nodeIndex)(featureIndex).size}") - } - agg(nodeIndex)(featureIndex)(binIndex).add(label) - featureIndex += 1 + globalNodeIndex - globalNodeIndexOffset } } @@ -709,7 +702,7 @@ object DecisionTree extends Serializable with Logging { if (metadata.unorderedFeatures.isEmpty) { orderedBinSeqOp(agg, treePoint, nodeIndex) } else { - someUnorderedBinSeqOp(agg, treePoint, nodeIndex) + someUnorderedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures) } } agg @@ -751,7 +744,6 @@ object DecisionTree extends Serializable with Logging { // Calculate best splits for all nodes at a given level timer.start("chooseSplits") val bestSplits = new Array[(Split, InformationGainStats)](numNodes) - val globalNodeIndexOffset = Node.maxNodesInSubtree(level - 1) + groupShift + 1 // Iterating over all nodes at this level var nodeIndex = 0 while (nodeIndex < numNodes) { From 92f7118876b0efad4c3227e3bd7dbc272898757b Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 20 Aug 2014 15:41:09 -0700 Subject: [PATCH 23/34] Added partly written DTStatsAggregator --- .../mllib/tree/impl/DTStatsAggregator.scala | 52 +++++++++++++++++++ .../spark/mllib/tree/impurity/Gini.scala | 3 -- .../spark/mllib/tree/impurity/Impurity.scala | 1 + 3 files changed, 53 insertions(+), 3 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala new file mode 100644 index 0000000000000..c23e35fe3e27a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable + + +/** + * :: Experimental :: + * DecisionTree statistics aggregator. + * This holds a flat array of statistics for a set of (nodes, features, bins) + * and helps with indexing. 
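+ * Layout (node-major): all stats for node 0 come first, then node 1, and so on;
+ * within a node, features appear in index order, and each bin of a feature
+ * occupies one contiguous block of statsSize values.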
+ * TODO: Allow views of Vector types to replace some of the code in here. + */ +private[tree] class DTStatsAggregator( + val numNodes: Int, + val numFeatures: Int, + val numBins: Array[Int], + val statsSize: Int) { + + require(numBins.size == numFeatures, s"DTStatsAggregator was given numBins" + + s" (of size ${numBins.size}) which did not match numFeatures = $numFeatures.") + + val featureOffsets: Array[Int] = numBins.scanLeft(0)(_ + _).map(statsSize * _) + + val allStatsSize: Int = numNodes * featureOffsets.last * statsSize + + val allStats: Array[Double] = new Array[Double](allStatsSize) + +// TODO: Make views + /* + Uses: + point access + + */ + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index e8aa4e9c7f7c1..9dab6c53cf9a1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -89,9 +89,6 @@ private[tree] class GiniAggregator(numClasses: Int) throw new IllegalArgumentException(s"GiniAggregator given label $label" + s" but requires label < numClasses (= ${counts.size}).") } - if (label.toInt >= counts.size) { - throw new RuntimeException(s"label = $label, counts = $counts") - } counts(label.toInt) += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index e6418960d894b..12dec1b276c07 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -113,6 +113,7 @@ private[tree] abstract class ImpurityAggregator(statsSize: Int) extends Serializ /** * Return the index of the largest element in this array. * If there are ties, the first maximal element is chosen. + * TODO: Move this elsewhere in Spark? */ protected def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { From f2166fde2258fffb71b7f968defcb9c49fcc94a7 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 20 Aug 2014 16:56:49 -0700 Subject: [PATCH 24/34] still working on DTStatsAggregator --- .../spark/mllib/tree/DecisionTree.scala | 5 +- .../mllib/tree/impl/DTStatsAggregator.scala | 95 +++++++++++++++++-- 2 files changed, 91 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6f45d1c32d0c6..87ad9fed377de 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -529,9 +529,10 @@ object DecisionTree extends Serializable with Logging { val featureValue = treePoint.binnedFeatures(featureIndex) // Update the left or right count for one bin. // Find all matching bins and increment their values. 
- val numCategoricalBins = bins(featureIndex).size //metadata.numBins(featureIndex) + val numCategoricalBins = bins(featureIndex).size / 2 //metadata.numBins(featureIndex) var binIndex = 0 while (binIndex < numCategoricalBins) { + // loop over bins, with possible offset, for fixed node, feature, label (for unordered categorical) if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label) } else { @@ -541,6 +542,7 @@ object DecisionTree extends Serializable with Logging { } } else { // Ordered feature + // random access, for fixed nodeIndex val binIndex = treePoint.binnedFeatures(featureIndex) agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label) } @@ -570,6 +572,7 @@ object DecisionTree extends Serializable with Logging { val numFeatures = treePoint.binnedFeatures.size var featureIndex = 0 while (featureIndex < numFeatures) { + // random access, for fixed nodeIndex val binIndex = treePoint.binnedFeatures(featureIndex) agg(nodeIndex)(featureIndex)(binIndex).add(label) featureIndex += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index c23e35fe3e27a..558f943ae2e66 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -29,24 +29,103 @@ import scala.collection.mutable */ private[tree] class DTStatsAggregator( val numNodes: Int, - val numFeatures: Int, val numBins: Array[Int], + unorderedFeatures: Set[Int], val statsSize: Int) { - require(numBins.size == numFeatures, s"DTStatsAggregator was given numBins" + - s" (of size ${numBins.size}) which did not match numFeatures = $numFeatures.") + val numFeatures: Int = numBins.size - val featureOffsets: Array[Int] = numBins.scanLeft(0)(_ + _).map(statsSize * _) + val isUnordered: Array[Boolean] = + Range(0, numFeatures).map(f => unorderedFeatures.contains(f)).toArray - val allStatsSize: Int = numNodes * featureOffsets.last * statsSize + private val featureOffsets: Array[Int] = { + def featureOffsetsCalc(total: Int, featureIndex: Int): Int = { + if (isUnordered(featureIndex)) { + total + 2 * numBins(featureIndex) + } else { + total + numBins(featureIndex) + } + } + Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray + } + /** + * Number of elements for each node, corresponding to stride between nodes in [[allStats]]. + */ + private val nodeStride: Int = featureOffsets.last * statsSize + + /** + * Total number of elements stored in this aggregator. + */ + val allStatsSize: Int = numNodes * nodeStride + + /** + * Flat array of elements. + * Index for start of stats for a (node, feature, bin) is: + * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize + * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex)) + * and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex)) + */ val allStats: Array[Double] = new Array[Double](allStatsSize) -// TODO: Make views + /** + * Get a view of the stats for a given (node, feature, bin) for ordered features. + * @return (flat stats array, start index of stats) The stats are contiguous in the array. 
+   */
+  def view(nodeIndex: Int, featureIndex: Int, binIndex: Int): (Array[Double], Int) = {
+    (allStats, nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize)
+  }
+
+  /**
+   * Pre-compute node offset for use with [[nodeView]].
+   */
+  def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
+
+  /**
+   * Get a view of the stats for a given (node, feature, bin) for ordered features.
+   * This uses a pre-computed node offset from [[getNodeOffset]].
+   * @return (flat stats array, start index of stats) The stats are contiguous in the array.
+   */
+  def nodeView(nodeOffset: Int, featureIndex: Int, binIndex: Int): (Array[Double], Int) = {
+    (allStats, nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize)
+  }
+
+  /**
+   * Pre-compute (node, feature) offset for use with [[nodeFeatureView]].
+   */
+  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int =
+    nodeIndex * nodeStride + featureOffsets(featureIndex)
+
+  /**
+   * Get a view of the stats for a given (node, feature, bin) for ordered features.
+   * This uses a pre-computed (node, feature) offset from [[getNodeFeatureOffset]].
+   * @return (flat stats array, start index of stats) The stats are contiguous in the array.
+   */
+  def nodeFeatureView(nodeFeatureOffset: Int, binIndex: Int): (Array[Double], Int) = {
+    (allStats, nodeFeatureOffset + binIndex * statsSize)
+  }
+
+  /**
+   * Merge this aggregator with another, and returns this aggregator.
+   * This method modifies this aggregator in-place.
+   */
+  def merge(other: DTStatsAggregator): DTStatsAggregator = {
+    //TODO
+  }
+
+  // TODO: Make views
   /*
-  Uses:
-    point access
+  VIEWS TO MAKE:
+  random access
+    impurityAggregator.update(statsAggregator.view(nodeIndex, featureIndex, binIndex), label)
+  random access for fixed nodeIndex
+    statsAggregator.getNodeOffset(nodeIndex) = nodeIndex * nodeStride
+    impurityAggregator.update(statsAggregator.nodeView(nodeOffset, featureIndex, binIndex), label)
+  loop over bins, with rightChildOffset, for fixed node, feature (for unordered categorical)
+    statsAggregator.getNodeFeatureOffset(nodeIndex, featureIndex) = nodeIndex * nodeStride + featureOffsets(featureIndex)
+    impurityAggregator.update(statsAggregator.nodeFeatureView(nodeFeatureOffset, binIndex, isLeft), label)
+
+  complete sum
   */
 }

From 807cd00cbb4d05d2421af88302f11877160105cf Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Sat, 23 Aug 2014 18:22:34 -0700
Subject: [PATCH 25/34] Finished DTStatsAggregator, a wrapper around the
 aggregate statistics for easy but hopefully efficient indexing. Modified old
 ImpurityAggregator classes and renamed them ImpurityCalculator; added
 ImpurityAggregator classes which work with DTStatsAggregator but do not store
 data. Unit tests all succeed.
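To make the division of labor concrete, here is a minimal standalone sketch of the
pattern (hypothetical stand-in classes for illustration only, not the code in this
patch): a stateless aggregator updates offsets into one shared flat array, while a
calculator owns a small copied slice for a single (node, feature, bin).

  object FlatStatsPatternSketch {
    // Stateless aggregator: knows only how to update/read the shared flat array.
    class VarAgg(val statsSize: Int = 3) {
      def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
        allStats(offset) += 1                  // count
        allStats(offset + 1) += label          // sum
        allStats(offset + 2) += label * label  // sum of squares
      }
      // Copies one statsSize-length slice out of the shared array.
      def getCalculator(allStats: Array[Double], offset: Int): VarCalc =
        new VarCalc(allStats.slice(offset, offset + statsSize))
    }
    // Calculator: owns its copied stats and computes impurity-style summaries.
    class VarCalc(val stats: Array[Double]) {
      def count: Long = stats(0).toLong
      def calculate(): Double = {              // population variance
        val (cnt, sum, sumSq) = (stats(0), stats(1), stats(2))
        if (cnt == 0) 0.0 else (sumSq - sum * sum / cnt) / cnt
      }
    }

    def main(args: Array[String]): Unit = {
      val agg = new VarAgg()
      val allStats = new Array[Double](2 * agg.statsSize)            // two bins, flat
      Seq(1.0, 2.0, 3.0).foreach(agg.update(allStats, 0, _))         // bin 0
      Seq(5.0, 5.0).foreach(agg.update(allStats, agg.statsSize, _))  // bin 1
      val calc = agg.getCalculator(allStats, 0)
      println(s"bin 0: count = ${calc.count}, variance = ${calc.calculate()}")
    }
  }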
--- .../spark/mllib/tree/DecisionTree.scala | 361 +++++++----------- .../mllib/tree/impl/DTStatsAggregator.scala | 137 +++++-- .../spark/mllib/tree/impurity/Entropy.scala | 56 +-- .../spark/mllib/tree/impurity/Gini.scala | 56 +-- .../spark/mllib/tree/impurity/Impurity.scala | 96 +++-- .../spark/mllib/tree/impurity/Variance.scala | 51 ++- 6 files changed, 382 insertions(+), 375 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 87ad9fed377de..39cdf9091f571 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint} +import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, DTStatsAggregator, TimeTracker, TreePoint} import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ @@ -81,7 +81,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) .persist(StorageLevel.MEMORY_AND_DISK) - val numFeatures = metadata.numFeatures // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree, plus 1 @@ -404,9 +403,6 @@ object DecisionTree extends Serializable with Logging { impurity, maxDepth, maxBins) } - - private val InvalidBinIndex = -1 - /** * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. @@ -517,34 +513,35 @@ object DecisionTree extends Serializable with Logging { * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ def someUnorderedBinSeqOp( - agg: Array[Array[Array[ImpurityAggregator]]], - treePoint: TreePoint, - nodeIndex: Int, bins: Array[Array[Bin]], unorderedFeatures: Set[Int]): Unit = { + agg: DTStatsAggregator, + treePoint: TreePoint, + nodeIndex: Int, bins: Array[Array[Bin]], unorderedFeatures: Set[Int]): Unit = { // Iterate over all features. val numFeatures = treePoint.binnedFeatures.size + val nodeOffset = agg.getNodeOffset(nodeIndex) var featureIndex = 0 while (featureIndex < numFeatures) { if (unorderedFeatures.contains(featureIndex)) { // Unordered feature val featureValue = treePoint.binnedFeatures(featureIndex) + val (leftNodeFeatureOffset, rightNodeFeatureOffset) = + agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) // Update the left or right count for one bin. // Find all matching bins and increment their values. 
- val numCategoricalBins = bins(featureIndex).size / 2 //metadata.numBins(featureIndex) + val numCategoricalBins = agg.numBins(featureIndex) var binIndex = 0 while (binIndex < numCategoricalBins) { - // loop over bins, with possible offset, for fixed node, feature, label (for unordered categorical) if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { - agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label) + agg.nodeFeatureUpdate(leftNodeFeatureOffset, binIndex, treePoint.label) } else { - agg(nodeIndex)(featureIndex)(numCategoricalBins + binIndex).add(treePoint.label) + agg.nodeFeatureUpdate(rightNodeFeatureOffset, binIndex, treePoint.label) } binIndex += 1 } } else { // Ordered feature - // random access, for fixed nodeIndex val binIndex = treePoint.binnedFeatures(featureIndex) - agg(nodeIndex)(featureIndex)(binIndex).add(treePoint.label) + agg.nodeUpdate(nodeOffset, featureIndex, binIndex, treePoint.label) } featureIndex += 1 } @@ -564,17 +561,17 @@ object DecisionTree extends Serializable with Logging { * @return agg */ def orderedBinSeqOp( - agg: Array[Array[Array[ImpurityAggregator]]], - treePoint: TreePoint, - nodeIndex: Int): Unit = { + agg: DTStatsAggregator, + treePoint: TreePoint, + nodeIndex: Int): Unit = { val label = treePoint.label + val nodeOffset = agg.getNodeOffset(nodeIndex) // Iterate over all features. val numFeatures = treePoint.binnedFeatures.size var featureIndex = 0 while (featureIndex < numFeatures) { - // random access, for fixed nodeIndex val binIndex = treePoint.binnedFeatures(featureIndex) - agg(nodeIndex)(featureIndex)(binIndex).add(label) + agg.nodeUpdate(nodeOffset, featureIndex, binIndex, label) featureIndex += 1 } } @@ -639,10 +636,6 @@ object DecisionTree extends Serializable with Logging { val numFeatures = metadata.numFeatures logDebug("numFeatures = " + numFeatures) - // numBins: Number of bins = 1 + number of possible splits - val numBins = bins(0).length - logDebug("numBins = " + numBins) - val numClasses = metadata.numClasses logDebug("numClasses = " + numClasses) @@ -695,8 +688,8 @@ object DecisionTree extends Serializable with Logging { * @return agg */ def binSeqOp( - agg: Array[Array[Array[ImpurityAggregator]]], - treePoint: TreePoint): Array[Array[Array[ImpurityAggregator]]] = { + agg: DTStatsAggregator, + treePoint: TreePoint): DTStatsAggregator = { val nodeIndex = treePointToNodeIndex(treePoint) // If the example does not reach this level, then nodeIndex < 0. // If the example reaches this level but is handled in a different group, @@ -711,39 +704,32 @@ object DecisionTree extends Serializable with Logging { agg } - /** - * Combines the aggregates from partitions. - * @param agg1 Array containing aggregates from one or more partitions - * @param agg2 Array containing aggregates from one or more partitions - * @return Combined aggregate from agg1 and agg2 - */ - def binCombOp( - agg1: Array[Array[Array[ImpurityAggregator]]], - agg2: Array[Array[Array[ImpurityAggregator]]]): Array[Array[Array[ImpurityAggregator]]] = { - var n = 0 - while (n < agg2.size) { - var f = 0 - while (f < agg2(n).size) { - var b = 0 - while (b < agg2(n)(f).size) { - agg1(n)(f)(b).merge(agg2(n)(f)(b)) - b += 1 - } - f += 1 - } - n += 1 - } - agg1 - } - // Calculate bin aggregates. 
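+    // Aggregation pattern: each partition folds its TreePoints into one local
+    // DTStatsAggregator via binSeqOp, and partial aggregators are then summed
+    // element-wise via DTStatsAggregator.binCombOp on the way back to the driver.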
timer.start("aggregation") - val binAggregates = { - val initAgg = getEmptyBinAggregates(metadata, numNodes) - input.treeAggregate(initAgg)(binSeqOp, binCombOp) + val binAggregates: DTStatsAggregator = { + val initAgg = new DTStatsAggregator(metadata, numNodes) + input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp) } timer.stop("aggregation") + /* + println("binAggregates for unordered cats:") + for (n <- Range(0, numNodes)) { + for (f <- Range(0, numFeatures)) { + val numBins = metadata.numBins(f) + for (b <- Range(0, numBins)) { + val (leftOffset, rightOffset) = binAggregates.getLeftRightNodeFeatureOffsets(n, f) + val leftCalc = binAggregates.getImpurityCalculator(leftOffset, b) + val rightCalc = binAggregates.getImpurityCalculator(rightOffset, b) + println(s"\t bin(n:$n)(f:$f)(b:$b)(left): $leftCalc") + println(s"\t bin(n:$n)(f:$f)(b:$b)(right): $rightCalc") + } + } + } + println("binAggregates, flat array:") + binAggregates.allStats.foreach(x => println(s"\t $x")) + */ + // Calculate best splits for all nodes at a given level timer.start("chooseSplits") val bestSplits = new Array[(Split, InformationGainStats)](numNodes) @@ -753,7 +739,7 @@ object DecisionTree extends Serializable with Logging { val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex) logDebug("node impurity = " + nodeImpurity) val (bestFeatureIndex, bestSplitIndex, bestGain) = - binsToBestSplit(binAggregates(nodeIndex), nodeImpurity, level, metadata) + binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata) bestSplits(nodeIndex) = (splits(bestFeatureIndex)(bestSplitIndex), bestGain) logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) @@ -766,20 +752,20 @@ object DecisionTree extends Serializable with Logging { /** * Calculate the information gain for a given (feature, split) based upon left/right aggregates. 
- * @param leftNodeAgg left node aggregates for this (feature, split) - * @param rightNodeAgg right node aggregate for this (feature, split) + * @param leftImpurityCalculator left node aggregates for this (feature, split) + * @param rightImpurityCalculator right node aggregate for this (feature, split) * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ def calculateGainForSplit( - leftNodeAgg: ImpurityAggregator, - rightNodeAgg: ImpurityAggregator, + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator, topImpurity: Double, level: Int, metadata: DecisionTreeMetadata): InformationGainStats = { - val leftCount = leftNodeAgg.count - val rightCount = rightNodeAgg.count + val leftCount = leftImpurityCalculator.count + val rightCount = rightImpurityCalculator.count val totalCount = leftCount + rightCount if (totalCount == 0) { @@ -787,8 +773,8 @@ object DecisionTree extends Serializable with Logging { return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) } - val parentNodeAgg = leftNodeAgg.copy - parentNodeAgg.merge(rightNodeAgg) + val parentNodeAgg = leftImpurityCalculator.copy + parentNodeAgg.add(rightImpurityCalculator) // impurity of parent node val impurity = if (level > 0) { topImpurity @@ -799,8 +785,8 @@ object DecisionTree extends Serializable with Logging { val predict = parentNodeAgg.predict val prob = parentNodeAgg.prob(predict) - val leftImpurity = leftNodeAgg.calculate() // Note: This equals 0 if count = 0 - val rightImpurity = rightNodeAgg.calculate() + val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 + val rightImpurity = rightImpurityCalculator.calculate() val leftWeight = leftCount / totalCount.toDouble val rightWeight = rightCount / totalCount.toDouble @@ -811,187 +797,97 @@ object DecisionTree extends Serializable with Logging { } /** - * Calculates information gain for all nodes splits. - * @param leftNodeAgg Aggregate stats, of dimensions (numFeatures, numSplits(feature)) - * @param rightNodeAgg Aggregate stats, of dimensions (numFeatures, numSplits(feature)) - * @param nodeImpurity Impurity for node being split. - * @return Info gain, of dimensions (numFeatures, numSplits(feature)) + * Find the best split for a node. + * @param binAggregates Bin statistics. + * @param nodeIndex Index for node to split in this (level, group). + * @param nodeImpurity Impurity of the node (nodeIndex). + * @return tuple (best feature index, best split index, information gain) */ - def calculateGainsForAllNodeSplits( - leftNodeAgg: Array[Array[ImpurityAggregator]], - rightNodeAgg: Array[Array[ImpurityAggregator]], + def binsToBestSplit( + binAggregates: DTStatsAggregator, + nodeIndex: Int, nodeImpurity: Double, level: Int, - metadata: DecisionTreeMetadata): Array[Array[InformationGainStats]] = { - val gains = new Array[Array[InformationGainStats]](metadata.numFeatures) + metadata: DecisionTreeMetadata): (Int, Int, InformationGainStats) = { + + logDebug("node impurity = " + nodeImpurity) + // For each (feature, split), calculate the gain, and select the best (feature, split). + // Initialize with infeasible values. + var bestFeatureIndex = Int.MinValue + var bestSplitIndex = Int.MinValue + var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) var featureIndex = 0 + // TODO: Change loops over splits into iterators. 
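+    // Strategy for this loop: for continuous and ordered categorical features,
+    // bin stats are first converted in place to prefix sums over bins; the left
+    // child of split s is then the prefix sum through bin s, and the right child
+    // is the total (last prefix sum) minus that left prefix.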
while (featureIndex < metadata.numFeatures) { - val numSplitsForFeature = metadata.numSplits(featureIndex) - gains(featureIndex) = new Array[InformationGainStats](numSplitsForFeature) - var splitIndex = 0 - while (splitIndex < numSplitsForFeature) { - gains(featureIndex)(splitIndex) = - calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex), - rightNodeAgg(featureIndex)(splitIndex), nodeImpurity, level, metadata) - splitIndex += 1 - } - featureIndex += 1 - } - gains - } - - /** - * Extracts left and right split aggregates. - * @param binData Aggregate array slice from getBinDataForNode. - * For classification: - * For unordered features, this is leftChildData ++ rightChildData, - * each of which is indexed by (feature, split/bin, class), - * with class being the least significant bit. - * For ordered features, this is of size numClasses * numBins * numFeatures. - * For regression: - * This is of size 2 * numFeatures * numBins. - * @return (leftNodeAgg, rightNodeAgg) pair of arrays. - * Each array is of size (numFeatures, numSplits(feature)). - * TODO: Extract in-place. - */ - def extractLeftRightNodeAggregates( - nodeAggregates: Array[Array[ImpurityAggregator]], - metadata: DecisionTreeMetadata): - (Array[Array[ImpurityAggregator]], Array[Array[ImpurityAggregator]]) = { - - val numClasses = metadata.numClasses - val numFeatures = metadata.numFeatures - - /** - * Reshape binData for this feature. - * Indexes binData as (feature, split, class) with class as the least significant bit. - * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value - */ - def findAggForUnorderedFeature( - binData: Array[Array[ImpurityAggregator]], - leftNodeAgg: Array[Array[ImpurityAggregator]], - rightNodeAgg: Array[Array[ImpurityAggregator]], - featureIndex: Int) { - // TODO: Don't pass in featureIndex; use index before call. - // Note: numBins = numSplits for unordered features. - val numBins = metadata.numBins(featureIndex) - leftNodeAgg(featureIndex) = binData(featureIndex).slice(0, numBins) - rightNodeAgg(featureIndex) = binData(featureIndex).slice(numBins, 2 * numBins) - } - - /** - * For ordered features (regression and classification with ordered features). - * The input binData is indexed as (feature, bin, class). - * This computes cumulative sums over splits. - * Each (feature, class) pair is handled separately. - * TODO: UPDATE DOC: Note: numSplits = numBins - 1. - * @param leftNodeAgg Each (feature, class) slice is an array over splits. - * Element i (i = 0, ..., numSplits - 2) is set to be - * the cumulative sum (from left) over binData for bins 0, ..., i. - * @param rightNodeAgg Each (feature, class) slice is an array over splits. - * Element i (i = 1, ..., numSplits - 1) is set to be - * the cumulative sum (from right) over binData for bins - * numBins - 1, ..., numBins - 1 - i. - * TODO: We could avoid doing one of these cumulative sums. - */ - def findAggForOrderedFeature( - binData: Array[Array[ImpurityAggregator]], - leftNodeAgg: Array[Array[ImpurityAggregator]], - rightNodeAgg: Array[Array[ImpurityAggregator]], - featureIndex: Int) { - - // TODO: Don't pass in featureIndex; use index before call. 
val numSplits = metadata.numSplits(featureIndex) - leftNodeAgg(featureIndex) = new Array[ImpurityAggregator](numSplits) - rightNodeAgg(featureIndex) = new Array[ImpurityAggregator](numSplits) - if (metadata.isContinuous(featureIndex)) { - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0) = binData(featureIndex)(0).copy - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numSplits - 1) = binData(featureIndex)(numSplits).copy - - // Iterate over all splits. - var splitIndex = 1 + //println(s"binsToBestSplit: feature $featureIndex (continuous)") + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) + var splitIndex = 0 while (splitIndex < numSplits) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(splitIndex) = leftNodeAgg(featureIndex)(splitIndex - 1).copy - leftNodeAgg(featureIndex)(splitIndex).merge(binData(featureIndex)(splitIndex)) - rightNodeAgg(featureIndex)(numSplits - 1 - splitIndex) = - rightNodeAgg(featureIndex)(numSplits - splitIndex).copy - rightNodeAgg(featureIndex)(numSplits - 1 - splitIndex).merge( - binData(featureIndex)(numSplits - splitIndex)) + binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) splitIndex += 1 } - } else { // ordered categorical feature - var splitIndex = 0 + // Find best split. + splitIndex = 0 while (splitIndex < numSplits) { - // no need to clone since no cumulative sum is needed - leftNodeAgg(featureIndex)(splitIndex) = binData(featureIndex)(splitIndex) - rightNodeAgg(featureIndex)(splitIndex) = binData(featureIndex)(splitIndex + 1) + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex) + val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + val gainStats = + calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + if (gainStats.gain > bestGainStats.gain) { + bestGainStats = gainStats + bestFeatureIndex = featureIndex + bestSplitIndex = splitIndex + } splitIndex += 1 } - } - } - - val leftNodeAgg = new Array[Array[ImpurityAggregator]](numFeatures) - val rightNodeAgg = new Array[Array[ImpurityAggregator]](numFeatures) - if (metadata.isClassification) { - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (metadata.isUnordered(featureIndex)) { - findAggForUnorderedFeature(nodeAggregates, leftNodeAgg, rightNodeAgg, featureIndex) - } else { - findAggForOrderedFeature(nodeAggregates, leftNodeAgg, rightNodeAgg, featureIndex) + } else if (metadata.isUnordered(featureIndex)) { + //println(s"binsToBestSplit: feature $featureIndex (unordered cat)") + // Unordered categorical feature + val (leftChildOffset, rightChildOffset) = + binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) + var splitIndex = 0 + while (splitIndex < numSplits) { + val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) + val gainStats = + calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + //println(s"\t split $splitIndex: gain: ${bestGainStats.gain}") 
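+          // Note: no prefix sums are needed for unordered features, since each
+          // bin already holds complete left/right child stats for one candidate
+          // subset split.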
+ if (gainStats.gain > bestGainStats.gain) { + bestGainStats = gainStats + bestFeatureIndex = featureIndex + bestSplitIndex = splitIndex + } + splitIndex += 1 } - featureIndex += 1 - } - } else { // Regression - var featureIndex = 0 - while (featureIndex < numFeatures) { - findAggForOrderedFeature(nodeAggregates, leftNodeAgg, rightNodeAgg, featureIndex) - featureIndex += 1 - } - } - (leftNodeAgg, rightNodeAgg) - } - - /** - * Find the best split for a node. - * @param binData Bin data slice for this node, given by getBinDataForNode. - * @param nodeImpurity impurity of the top node - * @return tuple (best feature index, best split index, information gain) - */ - def binsToBestSplit( - nodeAggregates: Array[Array[ImpurityAggregator]], - nodeImpurity: Double, - level: Int, - metadata: DecisionTreeMetadata): (Int, Int, InformationGainStats) = { - - logDebug("node impurity = " + nodeImpurity) - - // Extract left right node aggregates. - val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(nodeAggregates, metadata) - - // Calculate gains for all splits. - val gains = - calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity, level, metadata) - - val (bestFeatureIndex, bestSplitIndex, gainStats) = { - // Initialize with infeasible values. - var bestFeatureIndex = Int.MinValue - var bestSplitIndex = Int.MinValue - var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) - // Iterate over features. - var featureIndex = 0 - while (featureIndex < metadata.numFeatures) { - // Iterate over all splits. + } else { + //println(s"binsToBestSplit: feature $featureIndex (ordered cat)") + // Ordered categorical feature + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + // TODO: Choose adaptive ordering for ordered categorical features, and compute cumulative sum. + val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) var splitIndex = 0 - val numSplitsForFeature = metadata.numSplits(featureIndex) - while (splitIndex < numSplitsForFeature) { - val gainStats = gains(featureIndex)(splitIndex) + while (splitIndex < numSplits) { + binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) + splitIndex += 1 + } + // Find best split. + splitIndex = 0 + while (splitIndex < numSplits) { + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex) + val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + val gainStats = + calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + //println(s"\t split $splitIndex: gain: ${bestGainStats.gain}") if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats bestFeatureIndex = featureIndex @@ -999,11 +895,10 @@ object DecisionTree extends Serializable with Logging { } splitIndex += 1 } - featureIndex += 1 } - (bestFeatureIndex, bestSplitIndex, bestGainStats) + featureIndex += 1 } - (bestFeatureIndex, bestSplitIndex, gainStats) + (bestFeatureIndex, bestSplitIndex, bestGainStats) } /** @@ -1029,6 +924,7 @@ object DecisionTree extends Serializable with Logging { * For unordered features, aggregate is indexed by: (nodeIndex)(featureIndex)(2 * binIndex), * where the bins are ordered as (numBins left bins, numBins right bins). 
*/ + /* private def getEmptyBinAggregates( metadata: DecisionTreeMetadata, numNodes: Int): Array[Array[Array[ImpurityAggregator]]] = { @@ -1059,6 +955,7 @@ object DecisionTree extends Serializable with Logging { } agg } + */ /** * Returns splits and bins for decision tree calculation. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 558f943ae2e66..70a1c8e6dff24 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.tree.impl +import org.apache.spark.mllib.tree.impurity._ + import scala.collection.mutable @@ -28,15 +30,24 @@ import scala.collection.mutable * TODO: Allow views of Vector types to replace some of the code in here. */ private[tree] class DTStatsAggregator( - val numNodes: Int, - val numBins: Array[Int], - unorderedFeatures: Set[Int], - val statsSize: Int) { + metadata: DecisionTreeMetadata, + val numNodes: Int) extends Serializable { + + val impurityAggregator: ImpurityAggregator = metadata.impurity match { + case Gini => new GiniAggregator(metadata.numClasses) + case Entropy => new EntropyAggregator(metadata.numClasses) + case Variance => new VarianceAggregator() + case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") + } + + val statsSize: Int = impurityAggregator.statsSize - val numFeatures: Int = numBins.size + val numFeatures: Int = metadata.numFeatures + + val numBins: Array[Int] = metadata.numBins val isUnordered: Array[Boolean] = - Range(0, numFeatures).map(f => unorderedFeatures.contains(f)).toArray + Range(0, numFeatures).map(f => metadata.unorderedFeatures.contains(f)).toArray private val featureOffsets: Array[Int] = { def featureOffsetsCalc(total: Int, featureIndex: Int): Int = { @@ -69,40 +80,87 @@ private[tree] class DTStatsAggregator( val allStats: Array[Double] = new Array[Double](allStatsSize) /** - * Get a view of the stats for a given (node, feature, bin) for ordered features. - * @return (flat stats array, start index of stats) The stats are contiguous in the array. + * Get an [[ImpurityCalculator]] for a given (node, feature, bin). + * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset + * from [[getNodeFeatureOffset]]. + * For unordered features, this is a pre-computed + * (node, feature, left/right child) offset from + * [[getLeftRightNodeFeatureOffsets]]. + */ + def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = { + impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize) + } + + /** + * Update the stats for a given (node, feature, bin) for ordered features, using the given label. */ - def view(nodeIndex: Int, featureIndex: Int, binIndex: Int): (Array[Double], Int) = { - (allStats, nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize) + def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { + val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label) } /** - * Pre-compute node offset for use with [[nodeView]]. + * Pre-compute node offset for use with [[nodeUpdate]]. 
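+   * Pre-computing this once per data point avoids recomputing
+   * nodeIndex * nodeStride in the per-feature inner loops.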
*/ def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride /** - * Get a view of the stats for a given (node, feature, bin) for ordered features. + * Update the stats for a given (node, feature, bin) for ordered features, using the given label. * This uses a pre-computed node offset from [[getNodeOffset]]. - * @return (flat stats array, start index of stats) The stats are contiguous in the array. */ - def nodeView(nodeOffset: Int, featureIndex: Int, binIndex: Int): (Array[Double], Int) = { - (allStats, nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize) + def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { + val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label) } /** - * Pre-compute (node, feature) offset for use with [[nodeFeatureView]]. + * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * For ordered features only. */ - def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = + def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { + require(!isUnordered(featureIndex), + s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" + + s" for unordered feature $featureIndex.") nodeIndex * nodeStride + featureOffsets(featureIndex) + } /** - * Get a view of the stats for a given (node, feature, bin) for ordered features. - * This uses a pre-computed (node, feature) offset from [[getNodeFeatureOffset]]. - * @return (flat stats array, start index of stats) The stats are contiguous in the array. + * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * For unordered features only. */ - def nodeFeatureView(nodeFeatureOffset: Int, binIndex: Int): (Array[Double], Int) = { - (allStats, nodeFeatureOffset + binIndex * statsSize) + def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = { + require(isUnordered(featureIndex), + s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," + + s" but was called for ordered feature $featureIndex.") + val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex) + (baseOffset, baseOffset + numBins(featureIndex) * statsSize) + } + + /** + * Update the stats for a given (node, feature, bin), using the given label. + * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset + * from [[getNodeFeatureOffset]]. + * For unordered features, this is a pre-computed + * (node, feature, left/right child) offset from + * [[getLeftRightNodeFeatureOffsets]]. + */ + def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = { + impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label) + } + + /** + * For a given (node, feature), merge the stats for two bins. + * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset + * from [[getNodeFeatureOffset]]. + * For unordered features, this is a pre-computed + * (node, feature, left/right child) offset from + * [[getLeftRightNodeFeatureOffsets]]. + * @param binIndex The other bin is merged into this bin. + * @param otherBinIndex This bin is not modified. 
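+   * (binsToBestSplit uses this to turn per-bin stats into prefix sums over bins.)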
+ */ + def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { + impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize, + nodeFeatureOffset + otherBinIndex * statsSize) } /** @@ -110,22 +168,29 @@ private[tree] class DTStatsAggregator( * This method modifies this aggregator in-place. */ def merge(other: DTStatsAggregator): DTStatsAggregator = { - //TODO + require(allStatsSize == other.allStatsSize, + s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors." + + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") + var i = 0 + // TODO: Test BLAS.axpy + while (i < allStatsSize) { + allStats(i) += other.allStats(i) + i += 1 + } + this } - // TODO: Make views - /* - VIEWS TO MAKE: - random access - impurityAggregator.update(statsAggregator.view(nodeIndex, featureIndex, binIndex), label) - random access for fixed nodeIndex - statsAggregator.getNodeOffset(nodeIndex) = nodeIndex * nodeStride - impurityAggregator.update(statsAggregator.nodeView(nodeOffset, featureIndex, binIndex), label) - loop over bins, with rightChildOffset, for fixed node, feature (for unordered categorical) - statsAggregator.getNodeFeatureOffset(nodeIndex, featureIndex) = nodeIndex * nodeStride + featureOffsets(featureIndex) - impurityAggregator.update(statsAggregator.nodeFeatureView(nodeFeatureOffset, binIndex, isLeft), label) - - complete sum +} + +private[tree] object DTStatsAggregator extends Serializable { + + /** + * Combines two aggregates (modifying the first) and returns the combination. */ + def binCombOp( + agg1: DTStatsAggregator, + agg2: DTStatsAggregator): DTStatsAggregator = { + agg1.merge(agg2) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 77d64b69c39c7..e1667d474f676 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -78,50 +78,56 @@ object Entropy extends Impurity { private[tree] class EntropyAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { - def calculate(): Double = { - Entropy.calculate(counts, counts.sum) + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + if (label >= statsSize) { + throw new IllegalArgumentException(s"EntropyAggregator given label $label" + + s" but requires label < numClasses (= $statsSize).") + } + allStats(offset + label.toInt) += 1 } - def copy: EntropyAggregator = { - val tmp = new EntropyAggregator(counts.size) - tmp.counts = this.counts.clone() - tmp + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). 
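+   * Note: the returned calculator holds a copy of the statsSize values at offset,
+   * so mutating it does not affect the shared flat stats array.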
+ */ + def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = { + new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray) } - def add(label: Double): Unit = { - if (label >= counts.size) { - throw new IllegalArgumentException(s"EntropyAggregator given label $label" + - s" but requires label < numClasses (= ${counts.size}).") - } - counts(label.toInt) += 1 - } +} + +private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { + + def copy: EntropyCalculator = new EntropyCalculator(stats.clone()) - def count: Long = counts.sum.toLong + def calculate(): Double = Entropy.calculate(stats, stats.sum) + + def count: Long = stats.sum.toLong def predict: Double = if (count == 0) { 0 } else { - indexOfLargestArrayElement(counts) + indexOfLargestArrayElement(stats) } override def prob(label: Double): Double = { val lbl = label.toInt - require(lbl < counts.length, - s"EntropyAggregator.prob given invalid label: $lbl (should be < ${counts.length}") + require(lbl < stats.length, + s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}") val cnt = count if (cnt == 0) { 0 } else { - counts(lbl) / cnt + stats(lbl) / cnt } } - override def toString: String = { - s"EntropyAggregator(counts = [${counts.mkString(", ")}])" - } - - def newAggregator: EntropyAggregator = { - new EntropyAggregator(counts.size) - } + override def toString: String = s"EntropyCalculator(stats = [${stats.mkString(", ")}])" } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 9dab6c53cf9a1..d2b3fe3df576d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -74,50 +74,56 @@ object Gini extends Impurity { private[tree] class GiniAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { - def calculate(): Double = { - Gini.calculate(counts, counts.sum) + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + if (label >= statsSize) { + throw new IllegalArgumentException(s"GiniAggregator given label $label" + + s" but requires label < numClasses (= $statsSize).") + } + allStats(offset + label.toInt) += 1 } - def copy: GiniAggregator = { - val tmp = new GiniAggregator(counts.size) - tmp.counts = this.counts.clone() - tmp + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). 
+ */ + def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = { + new GiniCalculator(allStats.view(offset, offset + statsSize).toArray) } - def add(label: Double): Unit = { - if (label >= counts.size) { - throw new IllegalArgumentException(s"GiniAggregator given label $label" + - s" but requires label < numClasses (= ${counts.size}).") - } - counts(label.toInt) += 1 - } +} + +private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { + + def copy: GiniCalculator = new GiniCalculator(stats.clone()) - def count: Long = counts.sum.toLong + def calculate(): Double = Gini.calculate(stats, stats.sum) + + def count: Long = stats.sum.toLong def predict: Double = if (count == 0) { 0 } else { - indexOfLargestArrayElement(counts) + indexOfLargestArrayElement(stats) } override def prob(label: Double): Double = { val lbl = label.toInt - require(lbl < counts.length, - s"GiniAggregator.prob given invalid label: $lbl (should be < ${counts.length}") + require(lbl < stats.length, + s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}") val cnt = count if (cnt == 0) { 0 } else { - counts(lbl) / cnt + stats(lbl) / cnt } } - override def toString: String = { - s"GiniAggregator(counts = [${counts.mkString(", ")}])" - } - - def newAggregator: GiniAggregator = { - new GiniAggregator(counts.size) - } + override def toString: String = s"GiniCalculator(stats = [${stats.mkString(", ")}])" } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 12dec1b276c07..2954679ea4546 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -49,72 +49,92 @@ trait Impurity extends Serializable { } /** - * This class holds a set of sufficient statistics for computing impurity from a sample. + * Interface for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data itself. * @param statsSize Length of the vector of sufficient statistics. */ -private[tree] abstract class ImpurityAggregator(statsSize: Int) extends Serializable { +private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable { /** - * Sufficient statistics for calculating impurity. + * Merge the stats from one bin into another. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for (node, feature, bin) which is modified by the merge. + * @param otherOffset Start index of stats for (node, feature, other bin) which is not modified. */ - var counts: Array[Double] = new Array[Double](statsSize) - - def copy: ImpurityAggregator + def merge(allStats: Array[Double], offset: Int, otherOffset: Int): Unit = { + var i = 0 + while (i < statsSize) { + allStats(offset + i) += allStats(otherOffset + i) + i += 1 + } + } /** - * Add the given label to this aggregator. + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). */ - def add(label: Double): Unit + def update(allStats: Array[Double], offset: Int, label: Double): Unit /** - * Compute the impurity for the samples given so far. 
- * If no samples have been collected, return 0. + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). */ + def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator + +} + +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[ImpurityAggregator]], this class stores its own data and is for a single + * (node, feature, bin). + * @param stats Array of sufficient statistics. + */ +private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { + + def copy: ImpurityCalculator + def calculate(): Double /** - * Merge another aggregator into this one, modifying this aggregator. - * @param other Aggregator of the same type. - * @return merged aggregator + * Add the stats from another calculator into this one, modifying and returning this calculator. */ - def merge(other: ImpurityAggregator): ImpurityAggregator = { - require(counts.size == other.counts.size, - s"Two ImpurityAggregator instances cannot be merged with different counts sizes." + - s" Sizes are ${counts.size} and ${other.counts.size}.") + def add(other: ImpurityCalculator): ImpurityCalculator = { + require(stats.size == other.stats.size, + s"Two ImpurityCalculator instances cannot be added with different counts sizes." + + s" Sizes are ${stats.size} and ${other.stats.size}.") var i = 0 - while (i < other.counts.size) { - counts(i) += other.counts(i) + while (i < other.stats.size) { + stats(i) += other.stats(i) i += 1 } this } /** - * Number of samples added to this aggregator. + * Subtract the stats from another calculator from this one, modifying and returning this + * calculator. */ - def count: Long + def subtract(other: ImpurityCalculator): ImpurityCalculator = { + require(stats.size == other.stats.size, + s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." + + s" Sizes are ${stats.size} and ${other.stats.size}.") + var i = 0 + while (i < other.stats.size) { + stats(i) -= other.stats(i) + i += 1 + } + this + } - /** - * Create a new (empty) aggregator of the same type as this one. - */ - def newAggregator: ImpurityAggregator + def count: Long - /** - * Return the prediction corresponding to the set of labels given to this aggregator. - */ def predict: Double - /** - * Return the probability of the prediction returned by [[predict]], - * or -1 if no probability is available. - */ def prob(label: Double): Double = -1 - /** - * Return the index of the largest element in this array. - * If there are ties, the first maximal element is chosen. - * TODO: Move this elsewhere in Spark? 
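
These in-place add and subtract methods are what make the best-split search cheap: once bin statistics have been turned into prefix sums, the right-child stats for any split are the totals minus the left prefix. A minimal sketch of that usage, assuming only the methods defined in this class:

// leftStats: cumulative stats for bins 0..splitIndex (after the prefix-sum pass)
// totalStats: cumulative stats over all bins of this (node, feature)
def rightChildStats(
    leftStats: ImpurityCalculator,
    totalStats: ImpurityCalculator): ImpurityCalculator = {
  totalStats.copy.subtract(leftStats)   // subtract mutates and returns the copy
}

In binsToBestSplit the copy happens inside getImpurityCalculator (which extracts a fresh array), which is why subtract can safely mutate its receiver there.
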
- */ protected def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { case ((maxIndex, maxValue, currentIndex), currentValue) => @@ -125,7 +145,7 @@ private[tree] abstract class ImpurityAggregator(statsSize: Int) extends Serializ } } if (result._1 < 0) { - throw new RuntimeException("ImpurityAggregator internal error:" + + throw new RuntimeException("ImpurityCalculator internal error:" + " indexOfLargestArrayElement failed") } result._1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 63030cc2de5d9..0386db0a7d422 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -62,38 +62,51 @@ object Variance extends Impurity { } -private[tree] class VarianceAggregator extends ImpurityAggregator(3) with Serializable { +private[tree] class VarianceAggregator() + extends ImpurityAggregator(statsSize = 3) with Serializable { - def calculate(): Double = { - Variance.calculate(counts(0), counts(1), counts(2)) + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + allStats(offset) += 1 + allStats(offset + 1) += label + allStats(offset + 2) += label * label } - def copy: VarianceAggregator = { - val tmp = new VarianceAggregator() - tmp.counts = this.counts.clone() - tmp + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = { + new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray) } - def add(label: Double): Unit = { - counts(0) += 1 - counts(1) += label - counts(2) += label * label - } +} - def count: Long = counts(0).toLong +private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { + + require(stats.size == 3, + s"VarianceCalculator requires sufficient statistics array stats to be of length 3," + + s" but was given array of length ${stats.size}.") + + def copy: VarianceCalculator = new VarianceCalculator(stats.clone()) + + def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2)) + + def count: Long = stats(0).toLong def predict: Double = if (count == 0) { 0 } else { - counts(1) / count + stats(1) / count } override def toString: String = { - s"VarianceAggregator(cnt = ${counts(0)}, sum = ${counts(1)}, sum2 = ${counts(2)})" - } - - def newAggregator: VarianceAggregator = { - new VarianceAggregator() + s"VarianceAggregator(cnt = ${stats(0)}, sum = ${stats(1)}, sum2 = ${stats(2)})" } } From 6d32ccd97f251e873702cb5e6ff0fbb8072ccd9e Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sat, 23 Aug 2014 19:18:49 -0700 Subject: [PATCH 26/34] In DecisionTree.binsToBestSplit, changed loops to iterators to shorten code. 
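
The idiom this commit reaches for replaces a while loop that hand-tracks (bestIndex, bestGain) with Range(...).map { ... }.maxBy(...). A minimal, self-contained sketch of the pattern, with a hypothetical gainForSplit standing in for the real gain computation:

// Hypothetical stand-in for calculateGainForSplit.
def gainForSplit(splitIndex: Int): Double = 1.0 / (splitIndex + 1)

def bestSplit(numSplits: Int): (Int, Double) =
  Range(0, numSplits).map { splitIndex =>
    (splitIndex, gainForSplit(splitIndex))   // score every candidate split
  }.maxBy(_._2)                              // keep the pair with largest gain

// bestSplit(3) == (0, 1.0)

On ties, maxBy keeps the first maximal element, which matches the old loop's strict greater-than update.
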
--- .../spark/mllib/tree/DecisionTree.scala | 55 +++++-------------- 1 file changed, 14 insertions(+), 41 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 39cdf9091f571..434283f4e4c79 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -813,15 +813,9 @@ object DecisionTree extends Serializable with Logging { logDebug("node impurity = " + nodeImpurity) // For each (feature, split), calculate the gain, and select the best (feature, split). - // Initialize with infeasible values. - var bestFeatureIndex = Int.MinValue - var bestSplitIndex = Int.MinValue - var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) - var featureIndex = 0 - // TODO: Change loops over splits into iterators. - while (featureIndex < metadata.numFeatures) { + Range(0, metadata.numFeatures).map { featureIndex => val numSplits = metadata.numSplits(featureIndex) - if (metadata.isContinuous(featureIndex)) { + val (bestSplitIndex, bestGainStats) = if (metadata.isContinuous(featureIndex)) { //println(s"binsToBestSplit: feature $featureIndex (continuous)") // Cumulative sum (scanLeft) of bin statistics. // Afterwards, binAggregates for a bin is the sum of aggregates for @@ -833,39 +827,26 @@ object DecisionTree extends Serializable with Logging { splitIndex += 1 } // Find best split. - splitIndex = 0 - while (splitIndex < numSplits) { - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex) + Range(0, numSplits).map { case splitIdx => + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) - if (gainStats.gain > bestGainStats.gain) { - bestGainStats = gainStats - bestFeatureIndex = featureIndex - bestSplitIndex = splitIndex - } - splitIndex += 1 - } + (splitIdx, gainStats) + }.maxBy(_._2.gain) } else if (metadata.isUnordered(featureIndex)) { //println(s"binsToBestSplit: feature $featureIndex (unordered cat)") // Unordered categorical feature val (leftChildOffset, rightChildOffset) = binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) - var splitIndex = 0 - while (splitIndex < numSplits) { + Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) - //println(s"\t split $splitIndex: gain: ${bestGainStats.gain}") - if (gainStats.gain > bestGainStats.gain) { - bestGainStats = gainStats - bestFeatureIndex = featureIndex - bestSplitIndex = splitIndex - } - splitIndex += 1 - } + (splitIndex, gainStats) + }.maxBy(_._2.gain) } else { //println(s"binsToBestSplit: feature $featureIndex (ordered cat)") // Ordered categorical feature @@ -880,25 +861,17 @@ object DecisionTree extends Serializable with Logging { splitIndex += 1 } // Find best split. 
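
The cumulative-sum trick is easiest to verify with plain numbers: after the merge pass, bin i holds the sum of bins 0..i, so left-child stats are one lookup and right-child stats one subtraction. A toy sketch with scalar counts (a simplifying assumption; the real code subtracts whole sufficient-statistics vectors):

val binCounts = Array(3.0, 5.0, 2.0, 4.0)          // raw per-bin counts
val prefix = binCounts.scanLeft(0.0)(_ + _).tail   // Array(3.0, 8.0, 10.0, 14.0)
val splitIndex = 1
val leftCount = prefix(splitIndex)                 // bins 0..1 -> 8.0
val rightCount = prefix.last - leftCount           // bins 2..3 -> 6.0
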
- splitIndex = 0 - while (splitIndex < numSplits) { + Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) - //println(s"\t split $splitIndex: gain: ${bestGainStats.gain}") - if (gainStats.gain > bestGainStats.gain) { - bestGainStats = gainStats - bestFeatureIndex = featureIndex - bestSplitIndex = splitIndex - } - splitIndex += 1 - } + (splitIndex, gainStats) + }.maxBy(_._2.gain) } - featureIndex += 1 - } - (bestFeatureIndex, bestSplitIndex, bestGainStats) + (featureIndex, bestSplitIndex, bestGainStats) + }.maxBy(_._3.gain) } /** From 105f8ab5cc5f203fb47b2715d9e86d728be50eef Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 25 Aug 2014 09:49:49 -0700 Subject: [PATCH 27/34] Removed commented-out getEmptyBinAggregates from DecisionTree --- .../spark/mllib/tree/DecisionTree.scala | 39 ------------------- 1 file changed, 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 434283f4e4c79..86f95e5448044 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -891,45 +891,6 @@ object DecisionTree extends Serializable with Logging { } } - /** - * Get an empty instance of bin aggregates. - * For ordered features, aggregate is indexed by: (nodeIndex)(featureIndex)(binIndex). - * For unordered features, aggregate is indexed by: (nodeIndex)(featureIndex)(2 * binIndex), - * where the bins are ordered as (numBins left bins, numBins right bins). - */ - /* - private def getEmptyBinAggregates( - metadata: DecisionTreeMetadata, - numNodes: Int): Array[Array[Array[ImpurityAggregator]]] = { - val impurityAggregator: ImpurityAggregator = metadata.impurity match { - case Gini => new GiniAggregator(metadata.numClasses) - case Entropy => new EntropyAggregator(metadata.numClasses) - case Variance => new VarianceAggregator() - case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") - } - - val agg = Array.fill[Array[ImpurityAggregator]](numNodes, metadata.numFeatures)( - new Array[ImpurityAggregator](0)) - var nodeIndex = 0 - while (nodeIndex < numNodes) { - var featureIndex = 0 - while (featureIndex < metadata.numFeatures) { - val binMultiplier = if (metadata.isUnordered(featureIndex)) 2 else 1 - val effNumBins = metadata.numBins(featureIndex) * binMultiplier - agg(nodeIndex)(featureIndex) = new Array[ImpurityAggregator](effNumBins) - var binIndex = 0 - while (binIndex < effNumBins) { - agg(nodeIndex)(featureIndex)(binIndex) = impurityAggregator.newAggregator - binIndex += 1 - } - featureIndex += 1 - } - nodeIndex += 1 - } - agg - } - */ - /** * Returns splits and bins for decision tree calculation. * Continuous and categorical features are handled differently. From 37ca8459f4237427a978ab2dd2955df0ebcc0f0f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 25 Aug 2014 14:00:35 -0700 Subject: [PATCH 28/34] Fixed problem with how DecisionTree handles ordered categorical features. Background: Ordering refers to how nodes can split on categorical features. See the docs for more info. Previously, it chose a fixed ordering a priori. 
Now, it chooses a different ordering for each node, using the aggregated stats. For binary classification and regression, this is a bug fix. For multiclass classification, this improves a heuristic. Related changes: * splits and bins are no longer pre-computed for ordered categorical features since they are not needed. --- .../spark/mllib/tree/DecisionTree.scala | 318 ++++++++---------- .../spark/mllib/tree/impl/TreePoint.scala | 60 +--- .../spark/mllib/tree/DecisionTreeSuite.scala | 167 ++------- 3 files changed, 183 insertions(+), 362 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 86f95e5448044..094aabe0e6fc4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -72,9 +72,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // of the input data. timer.start("findSplitsBins") val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) - val numBins = bins(0).length timer.stop("findSplitsBins") - logDebug("numBins = " + numBins) + logDebug("numBins: feature: number of bins") + Range(0, metadata.numFeatures).foreach { featureIndex => + logDebug(s"\t$featureIndex\t${metadata.numBins(featureIndex)}") + } // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. @@ -96,7 +98,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") // TODO: Calculate memory usage more precisely. - val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins) + val numElementsPerNode = DecisionTree.getElementsPerNode(metadata) logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array @@ -474,12 +476,7 @@ object DecisionTree extends Serializable with Logging { featureValueUpperBound <= node.split.get.threshold } case Categorical => { - val featureValue = if (unorderedFeatures.contains(featureIndex)) { - binnedFeatures(featureIndex) - } else { - val binIndex = binnedFeatures(featureIndex) - bins(featureIndex)(binIndex).category - } + val featureValue = binnedFeatures(featureIndex) node.split.get.categories.contains(featureValue) } case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") @@ -515,7 +512,9 @@ object DecisionTree extends Serializable with Logging { def someUnorderedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, - nodeIndex: Int, bins: Array[Array[Bin]], unorderedFeatures: Set[Int]): Unit = { + nodeIndex: Int, + bins: Array[Array[Bin]], + unorderedFeatures: Set[Int]): Unit = { // Iterate over all features. val numFeatures = treePoint.binnedFeatures.size val nodeOffset = agg.getNodeOffset(nodeIndex) @@ -632,23 +631,19 @@ object DecisionTree extends Serializable with Logging { val numNodes = Node.maxNodesInLevel(level) / numGroups logDebug("numNodes = " + numNodes) - // Find the number of features by looking at the first sample. 
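
One consequence of the predictNodeIndex hunk above: with bins indexed by category, routing a point through a categorical split reduces to a set-membership test on the binned value, with no bin-to-category translation step. A toy illustration (all values assumed):

val binnedFeatures = Array(2, 0, 5)      // one category index per feature
val leftCategories = List(2.0, 4.0)      // categories the split sends left
val featureIndex = 0
val goesLeft = leftCategories.contains(binnedFeatures(featureIndex).toDouble)
// goesLeft == true: category 2 is in the split's category list
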
- val numFeatures = metadata.numFeatures - logDebug("numFeatures = " + numFeatures) - - val numClasses = metadata.numClasses - logDebug("numClasses = " + numClasses) - - val isMulticlass = metadata.isMulticlass - logDebug("isMulticlass = " + isMulticlass) - - val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures - logDebug("isMulticlassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) + logDebug("numFeatures = " + metadata.numFeatures) + logDebug("numClasses = " + metadata.numClasses) + logDebug("isMulticlass = " + metadata.isMulticlass) + logDebug("isMulticlassWithCategoricalFeatures = " + + metadata.isMulticlassWithCategoricalFeatures) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex - // Used for treePointToNodeIndex + // Used for treePointToNodeIndex to get an index for this (level, group). + // - Node.maxNodesInSubtree(level - 1) corrects for nodes before this level. + // - groupShift corrects for groups in this level before the current group. + // - 1 corrects for the fact that global node indices start at 1, not 0. val globalNodeIndexOffset = Node.maxNodesInSubtree(level - 1) + groupShift + 1 /** @@ -660,11 +655,8 @@ object DecisionTree extends Serializable with Logging { if (level == 0) { 0 } else { - val globalNodeIndex = predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures) - // Get index for this (level, group). - // - levelOffset corrects for nodes before this level. - // - groupShift corrects for groups in this level before the current group. - // - 1 corrects for the fact that globalNodeIndex starts at 1, not 0. + val globalNodeIndex = + predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures) globalNodeIndex - globalNodeIndexOffset } } @@ -712,24 +704,6 @@ object DecisionTree extends Serializable with Logging { } timer.stop("aggregation") - /* - println("binAggregates for unordered cats:") - for (n <- Range(0, numNodes)) { - for (f <- Range(0, numFeatures)) { - val numBins = metadata.numBins(f) - for (b <- Range(0, numBins)) { - val (leftOffset, rightOffset) = binAggregates.getLeftRightNodeFeatureOffsets(n, f) - val leftCalc = binAggregates.getImpurityCalculator(leftOffset, b) - val rightCalc = binAggregates.getImpurityCalculator(rightOffset, b) - println(s"\t bin(n:$n)(f:$f)(b:$b)(left): $leftCalc") - println(s"\t bin(n:$n)(f:$f)(b:$b)(right): $rightCalc") - } - } - } - println("binAggregates, flat array:") - binAggregates.allStats.foreach(x => println(s"\t $x")) - */ - // Calculate best splits for all nodes at a given level timer.start("chooseSplits") val bestSplits = new Array[(Split, InformationGainStats)](numNodes) @@ -738,11 +712,9 @@ object DecisionTree extends Serializable with Logging { while (nodeIndex < numNodes) { val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex) logDebug("node impurity = " + nodeImpurity) - val (bestFeatureIndex, bestSplitIndex, bestGain) = - binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata) - bestSplits(nodeIndex) = (splits(bestFeatureIndex)(bestSplitIndex), bestGain) - logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) - logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) + bestSplits(nodeIndex) = + binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits) + logDebug("best split = " + bestSplits(nodeIndex)._1) nodeIndex += 1 } timer.stop("chooseSplits") @@ -801,22 +773,22 @@ 
object DecisionTree extends Serializable with Logging { * @param binAggregates Bin statistics. * @param nodeIndex Index for node to split in this (level, group). * @param nodeImpurity Impurity of the node (nodeIndex). - * @return tuple (best feature index, best split index, information gain) + * @return tuple for best split: (Split, information gain) */ def binsToBestSplit( binAggregates: DTStatsAggregator, nodeIndex: Int, nodeImpurity: Double, level: Int, - metadata: DecisionTreeMetadata): (Int, Int, InformationGainStats) = { + metadata: DecisionTreeMetadata, + splits: Array[Array[Split]]): (Split, InformationGainStats) = { logDebug("node impurity = " + nodeImpurity) // For each (feature, split), calculate the gain, and select the best (feature, split). Range(0, metadata.numFeatures).map { featureIndex => val numSplits = metadata.numSplits(featureIndex) - val (bestSplitIndex, bestGainStats) = if (metadata.isContinuous(featureIndex)) { - //println(s"binsToBestSplit: feature $featureIndex (continuous)") + if (metadata.isContinuous(featureIndex)) { // Cumulative sum (scanLeft) of bin statistics. // Afterwards, binAggregates for a bin is the sum of aggregates for // that bin + all preceding bins. @@ -827,67 +799,119 @@ object DecisionTree extends Serializable with Logging { splitIndex += 1 } // Find best split. - Range(0, numSplits).map { case splitIdx => - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - val gainStats = - calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) - (splitIdx, gainStats) - }.maxBy(_._2.gain) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { case splitIdx => + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) + val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + val gainStats = + calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + (splitIdx, gainStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (metadata.isUnordered(featureIndex)) { - //println(s"binsToBestSplit: feature $featureIndex (unordered cat)") // Unordered categorical feature val (leftChildOffset, rightChildOffset) = binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - val gainStats = - calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) - (splitIndex, gainStats) - }.maxBy(_._2.gain) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) + val gainStats = + calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + (splitIndex, gainStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { - //println(s"binsToBestSplit: feature $featureIndex (ordered cat)") // Ordered categorical feature + val 
nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) + val numBins = metadata.numBins(featureIndex) + + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines which + * splits are considered. (With K categories, we consider K - 1 possible splits.) + * + * centroidForCategories is a list: (category, centroid) + */ + val centroidForCategories = if (metadata.isMulticlass) { + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + Range(0, numBins).map { case featureValue => + val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + categoryStats.calculate() + } else { + Double.MaxValue + } + (featureValue, centroid) + } + } else { // regression or binary classification + // For categorical variables in regression and binary classification, + // the bins are ordered by the centroid of their corresponding labels. + Range(0, numBins).map { case featureValue => + val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + categoryStats.predict + } else { + Double.MaxValue + } + (featureValue, centroid) + } + } + + logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) + + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) + + logDebug("Sorted centroids for categorical variable = " + + categoriesSortedByCentroid.mkString(",")) + // Cumulative sum (scanLeft) of bin statistics. // Afterwards, binAggregates for a bin is the sum of aggregates for // that bin + all preceding bins. - // TODO: Choose adaptive ordering for ordered categorical features, and compute cumulative sum. - val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) var splitIndex = 0 while (splitIndex < numSplits) { - binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) + val currentCategory = categoriesSortedByCentroid(splitIndex)._1 + val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 + binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory) splitIndex += 1 } + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last._1 // Find best split. 
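
This centroid ordering is the heart of the fix. For regression and binary classification the centroid is the mean label of each category; sorting categories by centroid means only K - 1 prefix splits need to be scored instead of all 2^(K-1) - 1 subsets. A toy computation of the ordering (data values assumed):

val labelsByCategory = Map(0 -> Seq(1.0, 1.0), 1 -> Seq(0.0), 2 -> Seq(1.0, 0.0))
val centroids = labelsByCategory.map { case (cat, labels) =>
  (cat, labels.sum / labels.size)                  // centroid = mean label
}
val categoriesSortedByCentroid = centroids.toList.sortBy(_._2)
// List((1, 0.0), (2, 0.5), (0, 1.0)): candidate left sets are {1} and {1, 2}
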
- Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex) - val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - val gainStats = - calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) - (splitIndex, gainStats) - }.maxBy(_._2.gain) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex)._1 + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + val gainStats = + calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + (splitIndex, gainStats) + }.maxBy(_._2.gain) + val categoriesForSplit = + categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) + val bestFeatureSplit = + new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) + (bestFeatureSplit, bestFeatureGainStats) } - (featureIndex, bestSplitIndex, bestGainStats) - }.maxBy(_._3.gain) + }.maxBy(_._2.gain) } /** * Get the number of values to be stored per node in the bin aggregates. - * - * @param numBins Number of bins = 1 + number of possible splits. */ - private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = { + private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = { + val totalBins = metadata.numBins.sum if (metadata.isClassification) { if (metadata.isMulticlassWithCategoricalFeatures) { - 2 * metadata.numClasses * numBins * metadata.numFeatures + 2 * metadata.numClasses * totalBins } else { - metadata.numClasses * numBins * metadata.numFeatures + metadata.numClasses * totalBins } } else { - 3 * numBins * metadata.numFeatures + 3 * totalBins } } @@ -898,6 +922,7 @@ object DecisionTree extends Serializable with Logging { * Continuous features: * For each feature, there are numBins - 1 possible splits representing the possible binary * decisions at each node in the tree. + * This finds locations (feature values) for splits using a subsample of the data. * * Categorical features: * For each feature, there is 1 bin per split. @@ -923,43 +948,41 @@ object DecisionTree extends Serializable with Logging { input: RDD[LabeledPoint], metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { - val isMulticlass = metadata.isMulticlass - logDebug("isMulticlass = " + isMulticlass) + logDebug("isMulticlass = " + metadata.isMulticlass) val numFeatures = metadata.numFeatures - // Calculate the number of sample for approximate quantile calculation. - val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) - val fraction = if (requiredSamples < metadata.numExamples) { - requiredSamples.toDouble / metadata.numExamples + // Sample the input only if there are continuous features. + val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous) + val sampledInput = if (hasContinuousFeatures) { + // Calculate the number of samples for approximate quantile calculation. 
+ val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) + val fraction = if (requiredSamples < metadata.numExamples) { + requiredSamples.toDouble / metadata.numExamples + } else { + 1.0 + } + logDebug("fraction of data used for calculating quantiles = " + fraction) + input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() } else { - 1.0 + new Array[LabeledPoint](0) } - logDebug("fraction of data used for calculating quantiles = " + fraction) - - // sampled input for RDD calculation - val sampledInput = - input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() - val numSamples = sampledInput.length metadata.quantileStrategy match { case Sort => val splits = new Array[Array[Split]](numFeatures) val bins = new Array[Array[Bin]](numFeatures) - var i = 0 - while (i < numFeatures) { - splits(i) = new Array[Split](metadata.numSplits(i)) - bins(i) = new Array[Bin](metadata.numBins(i)) - i += 1 - } // Find all splits. - // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { val numSplits = metadata.numSplits(featureIndex) + val numBins = metadata.numBins(featureIndex) if (metadata.isContinuous(featureIndex)) { + val numSamples = sampledInput.length + splits(featureIndex) = new Array[Split](numSplits) + bins(featureIndex) = new Array[Bin](numBins) val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex) logDebug("stride = " + stride) @@ -985,6 +1008,8 @@ object DecisionTree extends Serializable with Logging { if (metadata.isUnordered(featureIndex)) { // Unordered features: low-arity features in multiclass classification // 2^(maxFeatureValue- 1) - 1 combinations + splits(featureIndex) = new Array[Split](numSplits) + bins(featureIndex) = new Array[Bin](numBins) var splitIndex = 0 while (splitIndex < numSplits) { val categories: List[Double] = @@ -1009,72 +1034,11 @@ object DecisionTree extends Serializable with Logging { splitIndex += 1 } } else { - // Ordered features: high-arity features, or not multiclass classification - /* For a given categorical feature, use a subsample of the data - * to choose how to arrange possible splits. - * This examines each category and computes a centroid. - * These centroids are later used to sort the possible splits. - * centroidForCategories is a mapping: category (for the given feature) --> centroid - */ - val centroidForCategories = { - if (isMulticlass) { - // For categorical variables in multiclass classification, - // each bin is a category. The bins are sorted and they - // are ordered by calculating the impurity of their corresponding labels. - sampledInput.map(lp => (lp.features(featureIndex), lp.label)) - .groupBy(_._1) - .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) - .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum))) - } else { // regression or binary classification - // For categorical variables in regression and binary classification, - // each bin is a category. The bins are sorted and they - // are ordered by calculating the centroid of their corresponding labels. 
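
For continuous features, the code above picks split thresholds at evenly spaced ranks of the sorted sample rather than evenly spaced values. In sketch form (toy sample; the real code reads the stride from metadata.numBins):

val featureSamples = Array(4.0, 1.0, 3.0, 2.0, 5.0, 6.0).sorted   // 1,2,3,4,5,6
val numBins = 3
val stride = featureSamples.length.toDouble / numBins             // 2.0
val thresholds = Range(1, numBins).map { i =>
  featureSamples((i * stride).toInt)                              // ranks 2, 4
}
// thresholds == Vector(3.0, 5.0): approximate tertiles of the sample
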
- sampledInput.map(lp => (lp.features(featureIndex), lp.label)) - .groupBy(_._1) - .mapValues(x => x.map(_._2).sum / x.map(_._1).length) - } - } - - logDebug("centroid for categories = " + centroidForCategories.mkString(",")) - - // Check for missing categorical variables and putting them last in the sorted list. - val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() - for (i <- 0 until featureArity) { - if (centroidForCategories.contains(i)) { - fullCentroidForCategories(i) = centroidForCategories(i) - } else { - fullCentroidForCategories(i) = Double.MaxValue - } - } - - // bins sorted by centroids - val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - - logDebug("centroid for categorical variable = " + categoriesSortedByCentroid) - - var categoriesForSplit = List[Double]() - categoriesSortedByCentroid.iterator.zipWithIndex.foreach { - case ((category, value), binIndex) => - categoriesForSplit = category :: categoriesForSplit - if (binIndex < numSplits) { - splits(featureIndex)(binIndex) = - new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) - } - bins(featureIndex)(binIndex) = { - if (binIndex == 0) { - new Bin(new DummyCategoricalSplit(featureIndex, Categorical), - splits(featureIndex)(0), Categorical, category) - } else if (binIndex == numSplits) { - new Bin(splits(featureIndex)(binIndex - 1), - new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit), - Categorical, category) - } else { - new Bin(splits(featureIndex)(binIndex - 1), splits(featureIndex)(binIndex), - Categorical, category) - } - } - } + // Ordered features: high-arity features, or not multiclass classification + // Bins correspond to feature values, so we do not need to compute splits or bins + // beforehand. Splits are constructed as needed during training. + splits(featureIndex) = new Array[Split](0) + bins(featureIndex) = new Array[Bin](0) } } featureIndex += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index 54b099fb24b59..3bba26257b155 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -120,43 +120,6 @@ private[tree] object TreePoint { -1 } - /** - * Sequential search helper method to find bin for categorical feature in multiclass - * classification. The category is returned since each category can belong to multiple - * splits. The actual left/right child allocation per split is performed in the - * sequential phase of the bin aggregate operation. - */ - def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { - labeledPoint.features(featureIndex).toInt - } - - /** - * Sequential search helper method to find bin for categorical feature - * (for classification and regression). 
- */ - def sequentialBinSearchForOrderedCategoricalFeature(): Int = { - val featureCategories = categoricalFeaturesInfo(featureIndex) - val featureValue = labeledPoint.features(featureIndex) - var binIndex = 0 - while (binIndex < featureCategories) { - val bin = bins(featureIndex)(binIndex) - val categories = bin.highSplit.categories - if (categories.contains(featureValue)) { - return binIndex - } - binIndex += 1 - } - if (featureValue < 0 || featureValue >= featureCategories) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureCategories - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) - } - -1 - } - if (isFeatureContinuous) { // Perform binary search for finding bin for continuous features. val binIndex = binarySearchForBins() @@ -167,19 +130,18 @@ private[tree] object TreePoint { } binIndex } else { - // Perform sequential search to find bin for categorical features. - val binIndex = if (isUnorderedFeature) { - sequentialBinSearchForUnorderedCategoricalFeatureInClassification() - } else { - sequentialBinSearchForOrderedCategoricalFeature() - } - if (binIndex == -1) { - throw new RuntimeException("No bin was found for categorical feature." + - " This error can occur when given invalid data values (such as NaN)." + - s" Feature index: $featureIndex. isUnorderedFeature = $isUnorderedFeature." + - s" Feature value: ${labeledPoint.features(featureIndex)}") + // Categorical feature bins are indexed by feature values. + val featureCategories = categoricalFeaturesInfo(featureIndex) + val featureValue = labeledPoint.features(featureIndex) + if (featureValue < 0 || featureValue >= featureCategories) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureCategories - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) } - binIndex + featureValue.toInt } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 75dea3556a403..b2cdc4b94bfdd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -74,7 +74,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) } - test("Binary classification with binary features: split and bin calculation") { + test("Binary classification with binary (ordered) categorical features:" + + " split and bin calculation") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -92,47 +93,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(!metadata.isUnordered(featureIndex = 1)) assert(splits.length === 2) assert(bins.length === 2) - assert(splits(0).length === 1) - assert(bins(0).length === 2) - - // Check splits. 
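
The TreePoint change in this hunk makes categorical binning trivial: the bin index is the category value itself, guarded by a range check. As a standalone sketch mirroring the patch's validation:

def categoricalBin(featureValue: Double, featureArity: Int): Int = {
  require(featureValue >= 0 && featureValue < featureArity,
    s"categorical value $featureValue outside {0,...,${featureArity - 1}}")
  featureValue.toInt
}
// categoricalBin(2.0, 3) == 2; categoricalBin(3.0, 3) throws
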
- - assert(splits(0)(0).feature === 0) - assert(splits(0)(0).threshold === Double.MinValue) - assert(splits(0)(0).featureType === Categorical) - assert(splits(0)(0).categories.length === 1) - assert(splits(0)(0).categories.contains(1.0)) - - assert(splits(1)(0).feature === 1) - assert(splits(1)(0).threshold === Double.MinValue) - assert(splits(1)(0).featureType === Categorical) - assert(splits(1)(0).categories.length === 1) - assert(splits(1)(0).categories.contains(0.0)) - - // Check bins. - - assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories.length === 1) - assert(bins(0)(0).highSplit.categories.contains(1.0)) - - assert(bins(0)(1).lowSplit.categories.length === 1) - assert(bins(0)(1).lowSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.length === 2) - assert(bins(0)(1).highSplit.categories.contains(0.0)) - assert(bins(0)(1).highSplit.categories.contains(1.0)) - - assert(bins(1)(0).lowSplit.categories.length === 0) - assert(bins(1)(0).highSplit.categories.length === 1) - assert(bins(1)(0).highSplit.categories.contains(0.0)) - - assert(bins(1)(1).lowSplit.categories.length === 1) - assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length === 2) - assert(bins(1)(1).highSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.contains(1.0)) + // no bins or splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + assert(bins(0).length === 0) } - test("Binary classification with 3-category features, with no samples for one category") { + test("Binary classification with 3-ary (ordered) categorical features," + + " with no samples for one category") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -150,57 +117,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) - assert(splits(0).length === 2) - assert(bins(0).length === 3) - - // Check splits. - - assert(splits(0)(0).feature === 0) - assert(splits(0)(0).threshold === Double.MinValue) - assert(splits(0)(0).featureType === Categorical) - assert(splits(0)(0).categories.length === 1) - assert(splits(0)(0).categories.contains(1.0)) - - assert(splits(0)(1).feature === 0) - assert(splits(0)(1).threshold === Double.MinValue) - assert(splits(0)(1).featureType === Categorical) - assert(splits(0)(1).categories.length === 2) - assert(splits(0)(1).categories.contains(0.0)) - assert(splits(0)(1).categories.contains(1.0)) - - assert(splits(1)(0).feature === 1) - assert(splits(1)(0).threshold === Double.MinValue) - assert(splits(1)(0).featureType === Categorical) - assert(splits(1)(0).categories.length === 1) - assert(splits(1)(0).categories.contains(0.0)) - - assert(splits(1)(1).feature === 1) - assert(splits(1)(1).threshold === Double.MinValue) - assert(splits(1)(1).featureType === Categorical) - assert(splits(1)(1).categories.length === 2) - assert(splits(1)(1).categories.contains(0.0)) - assert(splits(1)(1).categories.contains(1.0)) - - // Check bins. 
- - assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories === splits(0)(0).categories) - - assert(bins(0)(1).lowSplit.categories === splits(0)(0).categories) - assert(bins(0)(1).highSplit.categories === splits(0)(1).categories) - - assert(bins(0)(2).lowSplit.categories === splits(0)(1).categories) - - assert(bins(0)(2).highSplit.categories === List(2.0, 0.0, 1.0)) - - assert(bins(1)(0).lowSplit.categories.length === 0) - assert(bins(1)(0).highSplit.categories === splits(1)(0).categories) - - assert(bins(1)(1).lowSplit.categories === splits(1)(0).categories) - assert(bins(1)(1).highSplit.categories === splits(1)(1).categories) - - assert(bins(1)(2).lowSplit.categories === splits(1)(1).categories) - assert(bins(1)(2).highSplit.categories === List(2.0, 1.0, 0.0)) + // no bins or splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + assert(bins(0).length === 0) } test("extract categories from a number for multiclass classification") { @@ -315,6 +234,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) + // 2^10 - 1 > 100, so categorical features will be ordered val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -322,46 +242,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) - assert(splits(0).length === 9) - assert(bins(0).length === 10) - - // 2^10 - 1 > 100, so categorical features will be ordered - - assert(splits(0)(0).feature === 0) - assert(splits(0)(0).threshold === Double.MinValue) - assert(splits(0)(0).featureType === Categorical) - assert(splits(0)(0).categories.length === 1) - assert(splits(0)(0).categories.contains(1.0)) - - assert(splits(0)(1).feature === 0) - assert(splits(0)(1).threshold === Double.MinValue) - assert(splits(0)(1).featureType === Categorical) - assert(splits(0)(1).categories.length === 2) - assert(splits(0)(1).categories.contains(2.0)) - - assert(splits(0)(2).feature === 0) - assert(splits(0)(2).threshold === Double.MinValue) - assert(splits(0)(2).featureType === Categorical) - assert(splits(0)(2).categories.length === 3) - assert(splits(0)(2).categories.contains(2.0)) - assert(splits(0)(2).categories.contains(1.0)) - - // Check bins. 
- - assert(bins(0)(0).category === 1.0) - assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories.length === 1) - assert(bins(0)(0).highSplit.categories.contains(1.0)) - - assert(bins(0)(1).category === 2.0) - assert(bins(0)(1).lowSplit.categories.length === 1) - assert(bins(0)(1).highSplit.categories.length === 2) - assert(bins(0)(1).highSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.contains(2.0)) + // no bins or splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + assert(bins(0).length === 0) } - test("Binary classification stump with all ordered categorical features") { + test("Binary classification stump with ordered categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -379,8 +266,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) - assert(splits(0).length === 2) - assert(bins(0).length === 3) + // no bins or splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + assert(bins(0).length === 0) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, @@ -398,7 +286,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.impurity > 0.2) } - test("Regression stump with 3-ary categorical features") { + test("Regression stump with 3-ary (ordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -430,7 +318,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.impurity > 0.2) } - test("Regression stump with binary categorical features") { + test("Regression stump with binary (ordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -445,6 +333,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(!metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) + println(model) validateRegressor(model, arr, 0.0) assert(model.numNodes === 3) assert(model.depth === 1) @@ -617,7 +506,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } } - test("Multiclass classification stump with 3-ary categorical features") { + test("Multiclass classification stump with 3-ary (unordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, @@ -674,7 +563,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(model.topNode.split.get.feature === 1) } - test("Multiclass classification stump with categorical features, with just enough bins") { + test("Multiclass classification stump with unordered categorical features," + + " with just enough bins") { val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) @@ -683,6 +573,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) 
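
The "just enough bins" setup above pins down when a categorical feature is treated as unordered: the problem must be multiclass, and the 2^(k-1) - 1 candidate category subsets of an arity-k feature must fit within maxBins (the exact inequality lives in DecisionTreeMetadata.buildMetadata). A sketch of the subset count:

// Candidate splits for an unordered categorical feature of arity k:
// nonempty proper subsets, counted once per left/right pairing.
def numUnorderedSplits(arity: Int): Int = (1 << (arity - 1)) - 1
// numUnorderedSplits(3) == 3, which maxBins = 2^(3-1) = 4 accommodates
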
assert(strategy.isMulticlassClassification) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) + assert(metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) @@ -731,13 +623,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } - test("Multiclass classification stump with continuous + categorical features") { + test("Multiclass classification stump with continuous + unordered categorical features") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) @@ -763,6 +656,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) @@ -771,6 +666,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 + println(s"bestSplit: $bestSplit") assert(bestSplit.feature === 0) assert(bestSplit.categories.length === 1) assert(bestSplit.categories.contains(1.0)) @@ -883,5 +779,4 @@ object DecisionTreeSuite { arr } - } From e676da174b1b789d0d01f06224eeab8be97ca68c Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 25 Aug 2014 19:36:24 -0700 Subject: [PATCH 29/34] Updated documentation for DecisionTree --- .../examples/mllib/DecisionTreeRunner.scala | 3 +- .../spark/mllib/tree/DecisionTree.scala | 68 ++++++++++--------- .../mllib/tree/impl/DTStatsAggregator.scala | 28 +++++--- .../tree/impl/DecisionTreeMetadata.scala | 30 +++----- .../spark/mllib/tree/impurity/Entropy.scala | 27 ++++++++ .../spark/mllib/tree/impurity/Gini.scala | 27 ++++++++ .../spark/mllib/tree/impurity/Impurity.scala | 30 ++++++-- .../spark/mllib/tree/impurity/Variance.scala | 23 +++++++ .../apache/spark/mllib/tree/model/Node.scala | 2 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 2 - 10 files changed, 170 insertions(+), 70 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 06771288fda96..cf3d2cca81ff6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -156,9 +156,8 @@ object DecisionTreeRunner { throw new IllegalArgumentException("Algo ${params.algo} not supported.") } - println("opt3") // Split into training, test. 
- val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest), seed = 12345) + val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) val training = splits(0).cache() val test = splits(1).cache() val numTraining = training.count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 094aabe0e6fc4..eca1284cf24dc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, DTStatsAggregator, TimeTracker, TreePoint} +import org.apache.spark.mllib.tree.impl._ import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ @@ -122,7 +122,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var break = false while (level <= maxDepth && !break) { - //println(s"LEVEL $level") logDebug("#####################################") logDebug("level = " + level) logDebug("#####################################") @@ -198,14 +197,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") - println(s"$timer") new DecisionTreeModel(topNode, strategy.algo) } } - object DecisionTree extends Serializable with Logging { /** @@ -456,13 +453,21 @@ object DecisionTree extends Serializable with Logging { * This function mimics prediction, passing an example from the root node down to a node * at the current level being trained; that node's index is returned. * + * @param node Node in tree from which to classify the given data point. + * @param binnedFeatures Binned feature vector for data point. + * @param bins possible bins for all features, indexed (numFeatures)(numBins) + * @param unorderedFeatures Set of indices of unordered features. * @return Leaf index if the data point reaches a leaf. * Otherwise, last node reachable in tree matching this example. * Note: This is the global node index, i.e., the index used in the tree. * This index is different from the index used during training a particular * set of nodes in a (level, group). */ - def predictNodeIndex(node: Node, binnedFeatures: Array[Int], bins: Array[Array[Bin]], unorderedFeatures: Set[Int]): Int = { + def predictNodeIndex( + node: Node, + binnedFeatures: Array[Int], + bins: Array[Array[Bin]], + unorderedFeatures: Set[Int]): Int = { if (node.isLeaf) { node.id } else { @@ -499,15 +504,18 @@ object DecisionTree extends Serializable with Logging { } /** - * Helper for binSeqOp. + * Helper for binSeqOp, for data containing some unordered (categorical) features. * - * @param agg Array storing aggregate calculation. - * For ordered features, this is of size: - * numClasses * numBins * numFeatures * numNodes. - * For unordered features, this is of size: - * 2 * numClasses * numBins * numFeatures * numNodes. - * @param treePoint Data point being aggregated. + * For ordered features, a single bin is updated. 
+ * For unordered features, bins correspond to subsets of categories; either the left or right bin + * for each subset is updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param treePoint Data point being aggregated. * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @param bins possible bins for all features, indexed (numFeatures)(numBins) + * @param unorderedFeatures Set of indices of unordered features. */ def someUnorderedBinSeqOp( agg: DTStatsAggregator, @@ -547,15 +555,13 @@ object DecisionTree extends Serializable with Logging { } /** - * Helper for binSeqOp: for regression and for classification with only ordered features. + * Helper for binSeqOp, for regression and for classification with only ordered features. * - * Performs a sequential aggregation over a partition for regression. - * For l nodes, k features, - * the count, sum, sum of squares of one of the p bins is incremented. + * For each feature, the sufficient statistics of one bin are updated. * - * @param agg Array storing aggregate calculation, updated by this function. - * Size: 3 * numBins * numFeatures * numNodes - * @param treePoint Data point being aggregated. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param treePoint Data point being aggregated. * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). * @return agg */ @@ -582,6 +588,7 @@ object DecisionTree extends Serializable with Logging { * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree + * @param nodes Array of all nodes in the tree. Used for matching data points to nodes. * @param splits possible splits for all features, indexed (numFeatures)(numSplits) * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param numGroups total number of node groups at the current level. Default value is set to 1. @@ -663,19 +670,12 @@ object DecisionTree extends Serializable with Logging { /** * Performs a sequential aggregation over a partition. - * For l nodes, k features, - * For classification: - * Either the left count or the right count of one of the bins is - * incremented based upon whether the feature is classified as 0 or 1. - * For regression: - * The count, sum, sum of squares of one of the bins is incremented. * - * @param agg Array storing aggregate calculation, updated by this function. - * Size for classification: - * Ordered features: numNodes * numFeatures * numBins. - * Unordered features: (2 * numNodes) * numFeatures * numBins. - * Size for regression: - * numNodes * numFeatures * numBins. + * Each data point contributes to one node. For each feature, + * the aggregate sufficient statistics are updated for the relevant bins. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). * @param treePoint Data point being aggregated. 
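
One subtlety compressed into the someUnorderedBinSeqOp description: each unordered bin encodes one candidate category subset, and a point updates the left half when its category belongs to the subset, the right half otherwise. In sketch form (subsets assumed here; the real code reads them from the precomputed splits):

val featureValue = 2.0
val candidateSubsets = Seq(List(0.0), List(2.0), List(0.0, 2.0))
candidateSubsets.zipWithIndex.foreach { case (subset, splitIndex) =>
  if (subset.contains(featureValue)) {
    // update stats at (leftChildOffset, splitIndex)
  } else {
    // update stats at (rightChildOffset, splitIndex)
  }
}
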
* @return agg */ @@ -883,8 +883,10 @@ object DecisionTree extends Serializable with Logging { val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 70a1c8e6dff24..8b7ab2d4dd267 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -19,20 +19,18 @@ package org.apache.spark.mllib.tree.impl import org.apache.spark.mllib.tree.impurity._ -import scala.collection.mutable - - /** - * :: Experimental :: * DecisionTree statistics aggregator. * This holds a flat array of statistics for a set of (nodes, features, bins) * and helps with indexing. - * TODO: Allow views of Vector types to replace some of the code in here. */ private[tree] class DTStatsAggregator( metadata: DecisionTreeMetadata, val numNodes: Int) extends Serializable { + /** + * [[ImpurityAggregator]] instance specifying the impurity type. + */ val impurityAggregator: ImpurityAggregator = metadata.impurity match { case Gini => new GiniAggregator(metadata.numClasses) case Entropy => new EntropyAggregator(metadata.numClasses) @@ -40,15 +38,27 @@ private[tree] class DTStatsAggregator( case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") } + /** + * Number of elements (Double values) used for the sufficient statistics of each bin. + */ val statsSize: Int = impurityAggregator.statsSize val numFeatures: Int = metadata.numFeatures + /** + * Number of bins for each feature. This is indexed by the feature index. + */ val numBins: Array[Int] = metadata.numBins - val isUnordered: Array[Boolean] = - Range(0, numFeatures).map(f => metadata.unorderedFeatures.contains(f)).toArray + /** + * Indicator for each feature of whether that feature is an unordered feature. + * TODO: Is Array[Boolean] any faster? + */ + def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex) + /** + * Offset for each feature for calculating indices into the [[allStats]] array. + */ private val featureOffsets: Array[Int] = { def featureOffsetsCalc(total: Int, featureIndex: Int): Int = { if (isUnordered(featureIndex)) { @@ -105,8 +115,9 @@ private[tree] class DTStatsAggregator( def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride /** + * Faster version of [[update]]. * Update the stats for a given (node, feature, bin) for ordered features, using the given label. - * This uses a pre-computed node offset from [[getNodeOffset]]. + * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. 
*/ def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize @@ -137,6 +148,7 @@ private[tree] class DTStatsAggregator( } /** + * Faster version of [[update]]. * Update the stats for a given (node, feature, bin), using the given label. * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset * from [[getNodeFeatureOffset]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 4fca7a4e4eb98..fd630624c11c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -24,30 +24,17 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impurity.Impurity -import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.rdd.RDD - -/* - * TODO: Add doc about ordered vs. unordered features. - * Ensure numBins is always greater than the categories. For multiclass classification, - * numBins should be greater than math.pow(2, maxCategories - 1) - 1. - * It's a limitation of the current implementation but a reasonable trade-off since features - * with large number of categories get favored over continuous features. - * - * This needs to be checked here instead of in Strategy since numBins can be determined - * by the number of training examples. - */ - - /** * Learning and dataset metadata for DecisionTree. * * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. * For regression: fixed at 0 (no meaning). + * @param maxBins Maximum number of bins, for all features. * @param featureArity Map: categorical feature index --> arity. * I.e., the feature takes values in {0, ..., arity - 1}. - * @param numBins numBins(featureIndex) = number of bins for feature + * @param numBins Number of bins for each feature. */ private[tree] class DecisionTreeMetadata( val numFeatures: Int, @@ -82,6 +69,11 @@ private[tree] class DecisionTreeMetadata( private[tree] object DecisionTreeMetadata { + /** + * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. + * This computes which categorical features will be ordered vs. unordered, + * as well as the number of splits and bins for each feature. + */ def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = { val numFeatures = input.take(1)(0).features.size @@ -94,6 +86,9 @@ private[tree] object DecisionTreeMetadata { val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt val log2MaxPossibleBinsp1 = math.log(maxPossibleBins + 1) / math.log(2.0) + // We check the number of bins here against maxPossibleBins. + // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified + // based on the number of training examples. val unorderedFeatures = new mutable.HashSet[Int]() val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) if (numClasses > 2) { @@ -104,11 +99,6 @@ private[tree] object DecisionTreeMetadata { unorderedFeatures.add(f) numBins(f) = numUnorderedBins(k) } else { - // TODO: Check the below k <= maxBins. - // Checking k <= maxPossibleBins should work. 
- // However, there may have been a 1-off error later on allocating 1 extra - // (unused) bin. - // TODO: Allow this case, where we simply will know nothing about some categories? require(k <= maxPossibleBins, s"maxBins (= $maxPossibleBins) should be greater than max categories " + s"in categorical features (>= $k)") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index e1667d474f676..1c8afc2d0f4bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -75,6 +75,12 @@ object Entropy extends Impurity { } +/** + * Class for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + * @param numClasses Number of classes for label. + */ private[tree] class EntropyAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { @@ -102,20 +108,41 @@ private[tree] class EntropyAggregator(numClasses: Int) } +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[EntropyAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ def copy: EntropyCalculator = new EntropyCalculator(stats.clone()) + /** + * Calculate the impurity from the stored sufficient statistics. + */ def calculate(): Double = Entropy.calculate(stats, stats.sum) + /** + * Number of data points accounted for in the sufficient statistics. + */ def count: Long = stats.sum.toLong + /** + * Prediction which should be made based on the sufficient statistics. + */ def predict: Double = if (count == 0) { 0 } else { indexOfLargestArrayElement(stats) } + /** + * Probability of the label given by [[predict]]. + */ override def prob(label: Double): Double = { val lbl = label.toInt require(lbl < stats.length, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index d2b3fe3df576d..5cfdf345d163c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -71,6 +71,12 @@ object Gini extends Impurity { } +/** + * Class for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + * @param numClasses Number of classes for label. + */ private[tree] class GiniAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { @@ -98,20 +104,41 @@ private[tree] class GiniAggregator(numClasses: Int) } +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[GiniAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { + /** + * Make a deep copy of this [[ImpurityCalculator]]. 
+ */ def copy: GiniCalculator = new GiniCalculator(stats.clone()) + /** + * Calculate the impurity from the stored sufficient statistics. + */ def calculate(): Double = Gini.calculate(stats, stats.sum) + /** + * Number of data points accounted for in the sufficient statistics. + */ def count: Long = stats.sum.toLong + /** + * Prediction which should be made based on the sufficient statistics. + */ def predict: Double = if (count == 0) { 0 } else { indexOfLargestArrayElement(stats) } + /** + * Probability of the label given by [[predict]]. + */ override def prob(label: Double): Double = { val lbl = label.toInt require(lbl < stats.length, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 2954679ea4546..5a047d6cb5480 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -22,6 +22,9 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} /** * :: Experimental :: * Trait for calculating information gain. + * This trait is used for + * (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]] + * (b) calculating impurity values from sufficient statistics. */ @Experimental trait Impurity extends Serializable { @@ -51,8 +54,8 @@ trait Impurity extends Serializable { /** * Interface for updating views of a vector of sufficient statistics, * in order to compute impurity from a sample. - * Note: Instances of this class do not hold the data itself. - * @param statsSize Length of the vector of sufficient statistics. + * Note: Instances of this class do not hold the data; they operate on views of the data. + * @param statsSize Length of the vector of sufficient statistics for one bin. */ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable { @@ -88,14 +91,20 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri /** * Stores statistics for one (node, feature, bin) for calculating impurity. - * Unlike [[ImpurityAggregator]], this class stores its own data and is for a single + * Unlike [[ImpurityAggregator]], this class stores its own data and is for a specific * (node, feature, bin). - * @param stats Array of sufficient statistics. + * @param stats Array of sufficient statistics for a (node, feature, bin). */ private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ def copy: ImpurityCalculator + /** + * Calculate the impurity from the stored sufficient statistics. + */ def calculate(): Double /** @@ -129,12 +138,25 @@ private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { this } + /** + * Number of data points accounted for in the sufficient statistics. + */ def count: Long + /** + * Prediction which should be made based on the sufficient statistics. + */ def predict: Double + /** + * Probability of the label given by [[predict]], or -1 if no probability is available. + */ def prob(label: Double): Double = -1 + /** + * Return the index of the largest array element. + * Fails if the array is empty. 
+ */ protected def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1, Double.MinValue, 0) { case ((maxIndex, maxValue, currentIndex), currentValue) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 0386db0a7d422..e9ccecb1b8067 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -62,6 +62,11 @@ object Variance extends Impurity { } +/** + * Class for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + */ private[tree] class VarianceAggregator() extends ImpurityAggregator(statsSize = 3) with Serializable { @@ -87,18 +92,36 @@ private[tree] class VarianceAggregator() } +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[VarianceAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { require(stats.size == 3, s"VarianceCalculator requires sufficient statistics array stats to be of length 3," + s" but was given array of length ${stats.size}.") + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ def copy: VarianceCalculator = new VarianceCalculator(stats.clone()) + /** + * Calculate the impurity from the stored sufficient statistics. + */ def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2)) + /** + * Number of data points accounted for in the sufficient statistics. + */ def count: Long = stats(0).toLong + /** + * Prediction which should be made based on the sufficient statistics. + */ def predict: Double = if (count == 0) { 0 } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 43023f31e0286..5cd4ad19b966f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -27,7 +27,7 @@ import org.apache.spark.mllib.linalg.Vector * Node in a decision tree. * * About node indexing: - * Nodes are indexed from 1. Node 1 is the root; nodes 2,3 are the left,right children. + * Nodes are indexed from 1. Node 1 is the root; nodes 2, 3 are the left, right children. * Node index 0 is not used.
* * @param id integer node id, from 1 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index b2cdc4b94bfdd..9a7aa9ecea5c9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -333,7 +333,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(!metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) - println(model) validateRegressor(model, arr, 0.0) assert(model.numNodes === 3) assert(model.depth === 1) @@ -666,7 +665,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 - println(s"bestSplit: $bestSplit") assert(bestSplit.feature === 0) assert(bestSplit.categories.length === 1) assert(bestSplit.categories.contains(1.0)) From 1485fcc2de396a88b30e0fad7d910a193c24a363 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 26 Aug 2014 19:34:23 -0700 Subject: [PATCH 30/34] Made some DecisionTree methods private. --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 12 ++++++------ .../org/apache/spark/mllib/tree/model/Node.scala | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index eca1284cf24dc..4e8f9d5eff4ad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -415,7 +415,7 @@ object DecisionTree extends Serializable with Logging { * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. * @return array (over nodes) of splits with best split for each node at a given level. */ - protected[tree] def findBestSplits( + private[tree] def findBestSplits( input: RDD[TreePoint], parentImpurities: Array[Double], metadata: DecisionTreeMetadata, @@ -463,7 +463,7 @@ object DecisionTree extends Serializable with Logging { * This index is different from the index used during training a particular * set of nodes in a (level, group). */ - def predictNodeIndex( + private def predictNodeIndex( node: Node, binnedFeatures: Array[Int], bins: Array[Array[Bin]], @@ -517,7 +517,7 @@ object DecisionTree extends Serializable with Logging { * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param unorderedFeatures Set of indices of unordered features. */ - def someUnorderedBinSeqOp( + private def someUnorderedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, nodeIndex: Int, @@ -565,7 +565,7 @@ object DecisionTree extends Serializable with Logging { * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). 
* @return agg */ - def orderedBinSeqOp( + private def orderedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, nodeIndex: Int): Unit = { @@ -729,7 +729,7 @@ object DecisionTree extends Serializable with Logging { * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ - def calculateGainForSplit( + private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, topImpurity: Double, @@ -775,7 +775,7 @@ object DecisionTree extends Serializable with Logging { * @param nodeImpurity Impurity of the node (nodeIndex). * @return tuple for best split: (Split, information gain) */ - def binsToBestSplit( + private def binsToBestSplit( binAggregates: DTStatsAggregator, nodeIndex: Int, nodeImpurity: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 5cd4ad19b966f..b0313c48756e6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -182,13 +182,13 @@ private[tree] object Node { * Return the maximum number of nodes which can be in the given level of the tree. * @param level Level of tree (0 = root). */ - private[tree] def maxNodesInLevel(level: Int): Int = 1 << level + def maxNodesInLevel(level: Int): Int = 1 << level /** * Return the maximum number of nodes which can be in or above the given level of the tree * (i.e., for the entire subtree from the root to this level). * @param level Level of tree (0 = root). */ - private[tree] def maxNodesInSubtree(level: Int): Int = (1 << level + 1) - 1 + def maxNodesInSubtree(level: Int): Int = (1 << level + 1) - 1 } From 46511547be6d193d19356ac31a978a83fe27e0b5 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 2 Sep 2014 15:04:58 -0700 Subject: [PATCH 31/34] Changed numBins semantics for unordered features. * Before: numBins = numSplits = (1 << k - 1) - 1 * Now: numBins = 2 * numSplits = 2 * [(1 << k - 1) - 1] * This also involved changing the semantics of: ** DecisionTreeMetadata.numUnorderedBins() Also made other small cleanups based on code review. --- .../spark/mllib/tree/DecisionTree.scala | 65 +++++++++---------- .../mllib/tree/impl/DTStatsAggregator.scala | 9 ++- .../tree/impl/DecisionTreeMetadata.scala | 52 +++++++++------ .../apache/spark/mllib/tree/model/Bin.scala | 7 +- .../apache/spark/mllib/tree/model/Node.scala | 9 ++- .../spark/mllib/tree/DecisionTreeSuite.scala | 4 +- 6 files changed, 82 insertions(+), 64 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 4e8f9d5eff4ad..712236ad62813 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -74,9 +74,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) timer.stop("findSplitsBins") logDebug("numBins: feature: number of bins") - Range(0, metadata.numFeatures).foreach { featureIndex => - logDebug(s"\t$featureIndex\t${metadata.numBins(featureIndex)}") - } + logDebug(Range(0, metadata.numFeatures).map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + }.mkString("\n")) // Bin feature values (TreePoint representation). 
// Cache input RDD for speedup during multiple passes. @@ -85,12 +85,14 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // depth of the decision tree val maxDepth = strategy.maxDepth - // the max number of nodes possible given the depth of the tree, plus 1 - val maxNumNodes_p1 = Node.maxNodesInLevel(maxDepth + 1) + require(maxDepth <= 30, + s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") + // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1 + val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1) // Initialize an array to hold parent impurity calculations for each node. - val parentImpurities = new Array[Double](maxNumNodes_p1) + val parentImpurities = new Array[Double](maxNumNodesPlus1) // dummy value for top node (updated during first split calculation) - val nodes = new Array[Node](maxNumNodes_p1) + val nodes = new Array[Node](maxNumNodesPlus1) // Calculate level for single group construction @@ -126,7 +128,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("level = " + level) logDebug("#####################################") - // Find best split for all nodes at a level. timer.start("findBestSplits") val splitsStatsForLevel: Array[(Split, InformationGainStats)] = @@ -134,9 +135,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) timer.stop("findBestSplits") - val levelNodeIndexOffset = Node.maxNodesInSubtree(level - 1) + val levelNodeIndexOffset = Node.startIndexInLevel(level) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - val nodeIndex = levelNodeIndexOffset + index + 1 // + 1 since nodes indexed from 1 + val nodeIndex = levelNodeIndexOffset + index // Extract info for this node (index) at the current level. timer.start("extractNodeInfo") @@ -504,7 +505,7 @@ object DecisionTree extends Serializable with Logging { } /** - * Helper for binSeqOp, for data containing some unordered (categorical) features. + * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. * * For ordered features, a single bin is updated. * For unordered features, bins correspond to subsets of categories; either the left or right bin @@ -517,7 +518,7 @@ object DecisionTree extends Serializable with Logging { * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param unorderedFeatures Set of indices of unordered features. */ - private def someUnorderedBinSeqOp( + private def mixedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, nodeIndex: Int, @@ -533,17 +534,16 @@ object DecisionTree extends Serializable with Logging { val featureValue = treePoint.binnedFeatures(featureIndex) val (leftNodeFeatureOffset, rightNodeFeatureOffset) = agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) - // Update the left or right count for one bin. - // Find all matching bins and increment their values. - val numCategoricalBins = agg.numBins(featureIndex) - var binIndex = 0 - while (binIndex < numCategoricalBins) { - if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { - agg.nodeFeatureUpdate(leftNodeFeatureOffset, binIndex, treePoint.label) + // Update the left or right bin for each split. 
+ val numSplits = agg.numSplits(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplits) { + if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) { + agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label) } else { - agg.nodeFeatureUpdate(rightNodeFeatureOffset, binIndex, treePoint.label) + agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label) } - binIndex += 1 + splitIndex += 1 } } else { // Ordered feature @@ -648,10 +648,9 @@ object DecisionTree extends Serializable with Logging { val groupShift = numNodes * groupIndex // Used for treePointToNodeIndex to get an index for this (level, group). - // - Node.maxNodesInSubtree(level - 1) corrects for nodes before this level. + // - Node.startIndexInLevel(level) gives the global index offset for nodes at this level. // - groupShift corrects for groups in this level before the current group. - // - 1 corrects for the fact that global node indices start at 1, not 0. - val globalNodeIndexOffset = Node.maxNodesInSubtree(level - 1) + groupShift + 1 + val globalNodeIndexOffset = Node.startIndexInLevel(level) + groupShift /** * Find the node index for the given example. @@ -690,7 +689,7 @@ object DecisionTree extends Serializable with Logging { if (metadata.unorderedFeatures.isEmpty) { orderedBinSeqOp(agg, treePoint, nodeIndex) } else { - someUnorderedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures) + mixedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures) } } agg @@ -907,11 +906,7 @@ object DecisionTree extends Serializable with Logging { private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = { val totalBins = metadata.numBins.sum if (metadata.isClassification) { - if (metadata.isMulticlassWithCategoricalFeatures) { - 2 * metadata.numClasses * totalBins - } else { - metadata.numClasses * totalBins - } + metadata.numClasses * totalBins } else { 3 * totalBins } @@ -1008,8 +1003,12 @@ object DecisionTree extends Serializable with Logging { // Categorical feature val featureArity = metadata.featureArity(featureIndex) if (metadata.isUnordered(featureIndex)) { - // Unordered features: low-arity features in multiclass classification - // 2^(maxFeatureValue- 1) - 1 combinations + // TODO: The second half of the bins are unused. Actually, we could just use + // splits and not build bins for unordered features. That should be part of + // a later PR since it will require changing other code (using splits instead + // of bins in a few places). + // Unordered features + // 2^(maxFeatureValue - 1) - 1 combinations splits(featureIndex) = new Array[Split](numSplits) bins(featureIndex) = new Array[Bin](numBins) var splitIndex = 0 @@ -1036,7 +1035,7 @@ object DecisionTree extends Serializable with Logging { splitIndex += 1 } } else { - // Ordered features: high-arity features, or not multiclass classification + // Ordered features // Bins correspond to feature values, so we do not need to compute splits or bins // beforehand. Splits are constructed as needed during training. 
splits(featureIndex) = new Array[Split](0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 8b7ab2d4dd267..3461d7724e530 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -25,7 +25,7 @@ import org.apache.spark.mllib.tree.impurity._ * and helps with indexing. */ private[tree] class DTStatsAggregator( - metadata: DecisionTreeMetadata, + val metadata: DecisionTreeMetadata, val numNodes: Int) extends Serializable { /** @@ -50,6 +50,11 @@ private[tree] class DTStatsAggregator( */ val numBins: Array[Int] = metadata.numBins + /** + * Number of splits for the given feature. + */ + def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex) + /** * Indicator for each feature of whether that feature is an unordered feature. * TODO: Is Array[Boolean] any faster? @@ -142,7 +147,7 @@ private[tree] class DTStatsAggregator( def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = { require(isUnordered(featureIndex), s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," + - s" but was called for ordered feature $featureIndex.") + s" but was called for ordered feature $featureIndex.") val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex) (baseOffset, baseOffset + numBins(featureIndex) * statsSize) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index fd630624c11c8..e95add7558bcf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -59,8 +59,13 @@ private[tree] class DecisionTreeMetadata( def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) + /** + * Number of splits for the given feature. + * For unordered features, there are 2 bins per split. + * For ordered features, there is 1 more bin than split. + */ def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { - numBins(featureIndex) + numBins(featureIndex) >> 1 } else { numBins(featureIndex) - 1 } @@ -84,33 +89,39 @@ private[tree] object DecisionTreeMetadata { } val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt - val log2MaxPossibleBinsp1 = math.log(maxPossibleBins + 1) / math.log(2.0) // We check the number of bins here against maxPossibleBins. // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified // based on the number of training examples. 
+ if (strategy.categoricalFeaturesInfo.nonEmpty) { + val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max + require(maxCategoriesPerFeature <= maxPossibleBins, + s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " + + s"in categorical features (= $maxCategoriesPerFeature)") + } + val unorderedFeatures = new mutable.HashSet[Int]() val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) if (numClasses > 2) { - strategy.categoricalFeaturesInfo.foreach { case (f, k) => - if (k - 1 < log2MaxPossibleBinsp1) { - // Note: The above check is equivalent to checking: - // numUnorderedBins = (1 << k - 1) - 1 < maxBins - unorderedFeatures.add(f) - numBins(f) = numUnorderedBins(k) + // Multiclass classification + val maxCategoriesForUnorderedFeature = + ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt + strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => + // Decide if some categorical features should be treated as unordered features, + // which require 2 * ((1 << numCategories - 1) - 1) bins. + // We do this check with log values to prevent overflows in case numCategories is large. + // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + if (numCategories <= maxCategoriesForUnorderedFeature) { + unorderedFeatures.add(featureIndex) + numBins(featureIndex) = numUnorderedBins(numCategories) } else { - require(k <= maxPossibleBins, - s"maxBins (= $maxPossibleBins) should be greater than max categories " + - s"in categorical features (>= $k)") - numBins(f) = k + numBins(featureIndex) = numCategories } } } else { - strategy.categoricalFeaturesInfo.foreach { case (f, k) => - require(k <= maxPossibleBins, - s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " + - s"in categorical features (= ${strategy.categoricalFeaturesInfo.values.max})") - numBins(f) = k + // Binary classification or regression + strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => + numBins(featureIndex) = numCategories } } @@ -122,9 +133,10 @@ private[tree] object DecisionTreeMetadata { /** * Given the arity of a categorical feature (arity = number of categories), * return the number of bins for the feature if it is to be treated as an unordered feature. + * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; + * there are math.pow(2, arity - 1) - 1 such splits. + * Each split has 2 corresponding bins. */ - def numUnorderedBins(arity: Int): Int = { - (1 << arity - 1) - 1 - } + def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index af35d88f713e5..0cad473782af1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType._ /** - * Used for "binning" the features bins for faster best split calculation. + * Used for "binning" the feature values for faster best split calculation. * * For a continuous feature, the bin is determined by a low and a high split, * where an example with featureValue falls into the bin s.t. @@ -30,13 +30,16 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * bins, splits, and feature values. 
The bin is determined by category/feature value. * However, the bins are not necessarily ordered by feature value; * they are ordered using impurity. + * * For unordered categorical features, there is a 1-1 correspondence between bins, splits, * where bins and splits correspond to subsets of feature values (in highSplit.categories). + * An unordered feature with k categories uses 2 * ((1 << k - 1) - 1) bins: a left and a right + * bin for each of the (1 << k - 1) - 1 partitionings of categories into 2 disjoint, non-empty sets. * * @param lowSplit signifying the lower threshold for the continuous feature to be * accepted in the bin * @param highSplit signifying the upper threshold for the continuous feature to be - * accepted in the bin + * accepted in the bin * @param featureType type of feature -- categorical or continuous * @param category categorical label value accepted in the bin for ordered features */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index b0313c48756e6..5b8a4cbed2306 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -151,12 +151,12 @@ private[tree] object Node { /** * Return the index of the left child of this node. */ - def leftChildIndex(nodeIndex: Int): Int = nodeIndex * 2 + def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1 /** * Return the index of the right child of this node. */ - def rightChildIndex(nodeIndex: Int): Int = nodeIndex * 2 + 1 + def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1 /** * Get the parent index of the given node, or 0 if it is the root. @@ -185,10 +185,9 @@ private[tree] object Node { def maxNodesInLevel(level: Int): Int = 1 << level /** - * Return the maximum number of nodes which can be in or above the given level of the tree - * (i.e., for the entire subtree from the root to this level). + * Return the index of the first node in the given level. * @param level Level of tree (0 = root). */ - def maxNodesInSubtree(level: Int): Int = (1 << level + 1) - 1 + def startIndexInLevel(level: Int): Int = 1 << level } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 9a7aa9ecea5c9..8e556c917b2e7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -148,7 +148,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 3) - assert(bins(0).length === 3) + assert(bins(0).length === 6) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -564,7 +564,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("Multiclass classification stump with unordered categorical features," + " with just enough bins") { - val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features + val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, From aa4e4df1989564f411e8e9e975618f6e715e2683 Mon Sep 17 00:00:00 2001 From: "Joseph K.
Bradley" Date: Wed, 3 Sep 2014 12:15:01 -0700 Subject: [PATCH 32/34] Updated DTStatsAggregator with bug fix (nodeString should not be multiplied by statsSize) --- .../org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 3461d7724e530..866d85a79bea1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -78,7 +78,7 @@ private[tree] class DTStatsAggregator( /** * Number of elements for each node, corresponding to stride between nodes in [[allStats]]. */ - private val nodeStride: Int = featureOffsets.last * statsSize + private val nodeStride: Int = featureOffsets.last /** * Total number of elements stored in this aggregator. From a2acea566449f07ce122f0904effe2eefe6c0f8a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 4 Sep 2014 17:39:15 -0700 Subject: [PATCH 33/34] Small optimizations based on profiling --- .../spark/mllib/tree/DecisionTree.scala | 2 +- .../spark/mllib/tree/impl/TreePoint.scala | 19 +++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 712236ad62813..dd766c12d28a4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -572,7 +572,7 @@ object DecisionTree extends Serializable with Logging { val label = treePoint.label val nodeOffset = agg.getNodeOffset(nodeIndex) // Iterate over all features. - val numFeatures = treePoint.binnedFeatures.size + val numFeatures = agg.numFeatures var featureIndex = 0 while (featureIndex < numFeatures) { val binIndex = treePoint.binnedFeatures(featureIndex) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index 3bba26257b155..47283fe4ac535 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -73,8 +73,9 @@ private[tree] object TreePoint { val arr = new Array[Int](numFeatures) var featureIndex = 0 while (featureIndex < numFeatures) { - arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex), - metadata.isUnordered(featureIndex), bins, metadata.featureArity) + val featureArity = metadata.featureArity.getOrElse(featureIndex, 0) + arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity, + metadata.isUnordered(featureIndex), bins) featureIndex += 1 } @@ -84,17 +85,16 @@ private[tree] object TreePoint { /** * Find bin for one (labeledPoint, feature). * + * @param featureArity 0 for continuous features; number of categories for categorical features. * @param isUnorderedFeature (only applies if feature is categorical) * @param bins Bins for features, of size (numFeatures, numBins). 
- * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity */ private def findBin( featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, + featureArity: Int, isUnorderedFeature: Boolean, - bins: Array[Array[Bin]], - categoricalFeaturesInfo: Map[Int, Int]): Int = { + bins: Array[Array[Bin]]): Int = { /** * Binary search helper method for continuous feature. @@ -120,7 +120,7 @@ private[tree] object TreePoint { -1 } - if (isFeatureContinuous) { + if (featureArity == 0) { // Perform binary search for finding bin for continuous features. val binIndex = binarySearchForBins() if (binIndex == -1) { @@ -131,13 +131,12 @@ private[tree] object TreePoint { binIndex } else { // Categorical feature bins are indexed by feature values. - val featureCategories = categoricalFeaturesInfo(featureIndex) val featureValue = labeledPoint.features(featureIndex) - if (featureValue < 0 || featureValue >= featureCategories) { + if (featureValue < 0 || featureValue >= featureArity) { throw new IllegalArgumentException( s"DecisionTree given invalid data:" + s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureCategories - 1}," + + s" {0,...,${featureArity - 1}," + s" but a data point gives it value $featureValue.\n" + " Bad data point: " + labeledPoint.toString) } From 00e44049300f13d16dd9bdffeee2923ec6300502 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 5 Sep 2014 15:30:11 -0700 Subject: [PATCH 34/34] optimization for TreePoint construction (pre-computing featureArity and isUnordered as arrays) --- .../spark/mllib/tree/impl/TreePoint.scala | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index 47283fe4ac535..35e361ae309cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -55,30 +55,40 @@ private[tree] object TreePoint { input: RDD[LabeledPoint], bins: Array[Array[Bin]], metadata: DecisionTreeMetadata): RDD[TreePoint] = { + // Construct arrays for featureArity and isUnordered for efficiency in the inner loop. + val featureArity: Array[Int] = new Array[Int](metadata.numFeatures) + val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures) + var featureIndex = 0 + while (featureIndex < metadata.numFeatures) { + featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0) + isUnordered(featureIndex) = metadata.isUnordered(featureIndex) + featureIndex += 1 + } input.map { x => - TreePoint.labeledPointToTreePoint(x, bins, metadata) + TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered) } } /** * Convert one LabeledPoint into its TreePoint representation. * @param bins Bins for features, of size (numFeatures, numBins). - * @param metadata DecisionTree training info, used for dataset metadata. + * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories + * for categorical features. + * @param isUnordered Array indexed by feature, with value true for unordered categorical features.
*/ private def labeledPointToTreePoint( labeledPoint: LabeledPoint, bins: Array[Array[Bin]], - metadata: DecisionTreeMetadata): TreePoint = { + featureArity: Array[Int], + isUnordered: Array[Boolean]): TreePoint = { val numFeatures = labeledPoint.features.size val arr = new Array[Int](numFeatures) var featureIndex = 0 while (featureIndex < numFeatures) { - val featureArity = metadata.featureArity.getOrElse(featureIndex, 0) - arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity, - metadata.isUnordered(featureIndex), bins) + arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), + isUnordered(featureIndex), bins) featureIndex += 1 } - new TreePoint(labeledPoint.label, arr) }
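
Appendix: a few standalone sketches illustrating the mechanics that recur in the patches above. None of this code is part of the patch series; TreeNode, walkDown, and the *Sketch objects are invented names for illustration.

First, level-wise training relies on predictNodeIndex to route each example from the root down to the node currently being trained. A minimal sketch of that walk, assuming continuous splits only (the real method also handles ordered and unordered categorical splits via binnedFeatures and bins):

    object PredictNodeIndexSketch {
      // Stand-in for Node: children are absent until their level is trained.
      case class TreeNode(
          id: Int,
          feature: Int,
          threshold: Double,
          left: Option[TreeNode],
          right: Option[TreeNode])

      // Returns the id of the last reachable node: a leaf if the tree is fully
      // grown along this path, or the frontier node whose children are untrained.
      def walkDown(node: TreeNode, features: Array[Double]): Int = {
        val child =
          if (features(node.feature) <= node.threshold) node.left else node.right
        child.map(walkDown(_, features)).getOrElse(node.id)
      }

      def main(args: Array[String]): Unit = {
        val leaf4 = TreeNode(4, 0, 0.0, None, None)
        val leaf5 = TreeNode(5, 0, 0.0, None, None)
        val node2 = TreeNode(2, 1, 7.0, Some(leaf4), Some(leaf5))
        val node3 = TreeNode(3, 0, 9.0, None, None)   // children not trained yet
        val root  = TreeNode(1, 0, 5.0, Some(node2), Some(node3))
        assert(walkDown(root, Array(4.0, 8.0)) == 5)  // root -> node 2 -> right leaf
        assert(walkDown(root, Array(6.0, 0.0)) == 3)  // stops at the frontier
      }
    }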
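The node bookkeeping in PATCHES 30-31 (leftChildIndex, rightChildIndex, maxNodesInLevel, startIndexInLevel) is plain binary-heap arithmetic on 1-based indices. A sketch of those identities; parentIndex's >> 1 body is an assumption here, chosen to match its doc ("or 0 if it is the root"):

    object NodeIndexingSketch {
      def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1          // 2 * i
      def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1   // 2 * i + 1
      def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1             // i / 2; 0 for the root
      def maxNodesInLevel(level: Int): Int = 1 << level                 // 2^level nodes per level
      def startIndexInLevel(level: Int): Int = 1 << level               // first index in the level

      def main(args: Array[String]): Unit = {
        assert(startIndexInLevel(2) == 4 && maxNodesInLevel(2) == 4)    // level 2 holds nodes 4..7
        assert(leftChildIndex(3) == 6 && rightChildIndex(3) == 7)
        assert(parentIndex(1) == 0)
        // 1 << level overflows Int at level 31, and startIndexInLevel(maxDepth + 1)
        // must fit in an Int -- hence the maxDepth <= 30 require added in PATCH 31.
      }
    }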
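PATCH 31's new bin semantics for unordered categorical features come down to two counts: with k categories there are 2^(k-1) - 1 ways to partition the categories into two disjoint, non-empty sets, and each such split now gets a left bin and a right bin. A sketch of the arithmetic; maxCategoriesForUnordered mirrors the log-based check in DecisionTreeMetadata.buildMetadata:

    object UnorderedBinsSketch {
      // Note Scala precedence: 1 << k - 1 parses as 1 << (k - 1).
      def numUnorderedSplits(arity: Int): Int = (1 << arity - 1) - 1

      // New semantics: numBins = 2 * numSplits (one left and one right bin per split).
      def numUnorderedBins(arity: Int): Int = 2 * numUnorderedSplits(arity)

      // Largest arity still satisfying 2 * ((1 << k - 1) - 1) <= maxPossibleBins,
      // computed with logs to avoid overflow when the arity is large.
      def maxCategoriesForUnordered(maxPossibleBins: Int): Int =
        ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt

      def main(args: Array[String]): Unit = {
        assert(numUnorderedSplits(3) == 3 && numUnorderedBins(3) == 6)
        assert(maxCategoriesForUnordered(6) == 3)  // the suite's "just enough bins" case
      }
    }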
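The one-line fix in PATCH 32 is easiest to see with the DTStatsAggregator layout written out. featureOffsets is already measured in Doubles (statsSize included), so the per-node stride is featureOffsets.last; the pre-fix featureOffsets.last * statsSize both over-allocated allStats and mis-indexed it. A simplified sketch covering ordered features only (unordered features double their bin count):

    class StatsLayoutSketch(numBins: Array[Int], statsSize: Int, numNodes: Int) {
      // featureOffsets(f) = number of Doubles before feature f within one node's block.
      private val featureOffsets: Array[Int] =
        numBins.scanLeft(0)((total, bins) => total + bins * statsSize)
      private val nodeStride: Int = featureOffsets.last  // the PATCH 32 fix
      val allStats: Array[Double] = new Array[Double](numNodes * nodeStride)

      // Same index formula as nodeUpdate: nodeOffset + featureOffset + bin * statsSize + stat.
      def add(node: Int, feature: Int, bin: Int, stat: Int, value: Double): Unit = {
        allStats(node * nodeStride + featureOffsets(feature) + bin * statsSize + stat) += value
      }
    }

For example, with numBins = Array(4, 6) and statsSize = 3 (regression), the stride is 30 Doubles per node; multiplying by statsSize again would have made it 90.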
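Finally, PATCHES 33-34 thread featureArity (0 meaning continuous) through TreePoint so findBin avoids per-feature Map lookups in the inner loop. A condensed, hypothetical sketch of the binning rule itself: thresholds stands in for the Bin boundaries, and a linear scan replaces the real code's binary search:

    object FindBinSketch {
      def binFeature(value: Double, thresholds: Array[Double], featureArity: Int): Int = {
        if (featureArity == 0) {
          // Continuous: first bin whose upper threshold admits the value
          // (lowSplit.threshold < value <= highSplit.threshold); the top bin is unbounded.
          val i = thresholds.indexWhere(value <= _)
          if (i == -1) thresholds.length else i
        } else {
          // Ordered categorical: the category itself is the bin index.
          require(value >= 0 && value < featureArity,
            s"categorical feature value $value out of range [0, $featureArity)")
          value.toInt
        }
      }

      def main(args: Array[String]): Unit = {
        assert(binFeature(2.5, Array(1.0, 3.0, 5.0), 0) == 1)  // falls in (1.0, 3.0]
        assert(binFeature(9.9, Array(1.0, 3.0, 5.0), 0) == 3)  // above all thresholds
        assert(binFeature(2.0, Array.empty, 4) == 2)           // category 2 -> bin 2
      }
    }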