diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 07e98a142b10..17aba54f21bb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -276,14 +276,10 @@ private[tree] class LearningNode( new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) } else { - if (stats.valid) { - new LeafNode(stats.impurityCalculator.predict, stats.impurity, - stats.impurityCalculator) - } else { - // Here we want to keep same behavior with the old mllib.DecisionTreeModel - new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) - } - + assert(stats != null, "Unknown error during Decision Tree learning. Could not convert " + + "LearningNode to Node") + new LeafNode(stats.impurityCalculator.predict, stats.impurity, + stats.impurityCalculator) } } @@ -334,7 +330,7 @@ private[tree] object LearningNode { id: Int, isLeaf: Boolean, stats: ImpurityStats): LearningNode = { - new LearningNode(id, None, None, None, false, stats) + new LearningNode(id, None, None, None, isLeaf, stats) } /** Create an empty node with the given node index. Values must be set later on. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala new file mode 100644 index 000000000000..07e4a16e2990 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.ml.tree.Split + +/** + * Helpers for updating DTStatsAggregators during collection of sufficient stats for tree training. + */ +private[impl] object AggUpdateUtils { + + /** + * Updates the parent node stats of the passed-in impurity aggregator with the labels + * corresponding to the feature values at indices [from, to). + * @param indices Array of row indices for feature values; indices(i) = row index of the ith + * feature value + */ + private[impl] def updateParentImpurity( + statsAggregator: DTStatsAggregator, + indices: Array[Int], + from: Int, + to: Int, + instanceWeights: Array[Double], + labels: Array[Double]): Unit = { + from.until(to).foreach { idx => + val rowIndex = indices(idx) + val label = labels(rowIndex) + statsAggregator.updateParent(label, instanceWeights(rowIndex)) + } + } + + /** + * Update aggregator for an (unordered feature, label) pair + * @param featureSplits Array of splits for the current feature + */ + private[impl] def updateUnorderedFeature( + agg: DTStatsAggregator, + featureValue: Int, + label: Double, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + instanceWeight: Double): Unit = { + val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) + // Each unordered split has a corresponding bin for impurity stats of data points that fall + // onto the 
left side of the split. For each unordered split, update left-side bin if applicable + // for the current data point. + val numSplits = agg.metadata.numSplits(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplits) { + if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { + agg.featureUpdate(leftNodeFeatureOffset, splitIndex, label, instanceWeight) + } + splitIndex += 1 + } + } + + /** Update aggregator for an (ordered feature, label) pair */ + private[impl] def updateOrderedFeature( + agg: DTStatsAggregator, + featureValue: Int, + label: Double, + featureIndexIdx: Int, + instanceWeight: Double): Unit = { + // The bin index of an ordered feature is just the feature value itself + val binIndex = featureValue + agg.update(featureIndexIdx, binIndex, label, instanceWeight) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala new file mode 100644 index 000000000000..0dd021eec247 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.mllib.tree.impurity._ +import org.apache.spark.mllib.tree.model.ImpurityStats + +/** Helper methods for impurity-related calculations during node split decisions. */ +private[impl] object ImpurityUtils { + + /** + * Get impurity calculator containing statistics for all labels for rows corresponding to + * feature values in [from, to). + * @param indices indices(i) = row index corresponding to ith feature value + */ + private[impl] def getParentImpurityCalculator( + metadata: DecisionTreeMetadata, + indices: Array[Int], + from: Int, + to: Int, + instanceWeights: Array[Double], + labels: Array[Double]): ImpurityCalculator = { + // Compute sufficient stats (e.g. label counts) for all data at the current node, + // store result in currNodeStatsAgg.parentStats so that we can share it across + // all features for the current node + val currNodeStatsAgg = new DTStatsAggregator(metadata, featureSubset = None) + AggUpdateUtils.updateParentImpurity(currNodeStatsAgg, indices, from, to, + instanceWeights, labels) + currNodeStatsAgg.getParentImpurityCalculator() + } + + /** + * Calculate the impurity statistics for a given (feature, split) based upon left/right + * aggregates. + * + * @param parentImpurityCalculator An ImpurityCalculator containing the impurity stats + * of the node currently being split. 
+ * @param leftImpurityCalculator left node aggregates for this (feature, split) + * @param rightImpurityCalculator right node aggregate for this (feature, split) + * @param metadata learning and dataset metadata for DecisionTree + * @return Impurity statistics for this (feature, split) + */ + private[impl] def calculateImpurityStats( + parentImpurityCalculator: ImpurityCalculator, + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator, + metadata: DecisionTreeMetadata): ImpurityStats = { + + val impurity: Double = parentImpurityCalculator.calculate() + + val leftCount = leftImpurityCalculator.count + val rightCount = rightImpurityCalculator.count + + val totalCount = leftCount + rightCount + + // If left child or right child doesn't satisfy minimum instances per node, + // then this split is invalid, return invalid information gain stats. + if ((leftCount < metadata.minInstancesPerNode) || + (rightCount < metadata.minInstancesPerNode)) { + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + } + + val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 + val rightImpurity = rightImpurityCalculator.calculate() + + val leftWeight = leftCount / totalCount.toDouble + val rightWeight = rightCount / totalCount.toDouble + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + // If information gain doesn't satisfy minimum information gain, + // then this split is invalid, return invalid information gain stats. + if (gain < metadata.minInfoGain) { + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + } + + // If information gain is non-positive but doesn't violate the minimum info gain constraint, + // return a stats object with correct values but valid = false to indicate that we should not + // split. 
+ if (gain <= 0) { + return new ImpurityStats(gain, impurity, parentImpurityCalculator, leftImpurityCalculator, + rightImpurityCalculator, valid = false) + } + + + new ImpurityStats(gain, impurity, parentImpurityCalculator, + leftImpurityCalculator, rightImpurityCalculator) + } + + /** + * Given an impurity aggregator containing label statistics for a given (node, feature, bin), + * returns the corresponding "centroid", used to order bins while computing best splits. + * + * @param metadata learning and dataset metadata for DecisionTree + */ + private[impl] def getCentroid( + metadata: DecisionTreeMetadata, + binStats: ImpurityCalculator): Double = { + + if (binStats.count != 0) { + if (metadata.isMulticlass) { + // multiclass classification + // For categorical features in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + binStats.calculate() + } else if (metadata.isClassification) { + // binary classification + // For categorical features in binary classification, + // the bins are ordered by the count of class 1. + binStats.stats(1) + } else { + // regression + // For categorical features in regression, + // the bins are ordered by the prediction. 
+ binStats.predict + } + } else { + Double.MaxValue + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index acfc6399c553..f8c3dd7ff2e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.tree.impl import java.io.IOException -import scala.collection.mutable +import scala.collection.{mutable, SeqView} import scala.util.Random import org.apache.spark.internal.Logging @@ -280,23 +280,14 @@ private[spark] object RandomForest extends Logging { featureIndexIdx } if (unorderedFeatures.contains(featureIndex)) { - // Unordered feature - val featureValue = treePoint.binnedFeatures(featureIndex) - val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) - // Update the left or right bin for each split. - val numSplits = agg.metadata.numSplits(featureIndex) - val featureSplits = splits(featureIndex) - var splitIndex = 0 - while (splitIndex < numSplits) { - if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { - agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) - } - splitIndex += 1 - } + AggUpdateUtils.updateUnorderedFeature(agg, + featureValue = treePoint.binnedFeatures(featureIndex), label = treePoint.label, + featureIndex = featureIndex, featureIndexIdx = featureIndexIdx, + featureSplits = splits(featureIndex), instanceWeight = instanceWeight) } else { - // Ordered feature - val binIndex = treePoint.binnedFeatures(featureIndex) - agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight) + AggUpdateUtils.updateOrderedFeature(agg, + featureValue = treePoint.binnedFeatures(featureIndex), label = treePoint.label, + featureIndexIdx = featureIndexIdx, instanceWeight = instanceWeight) } featureIndexIdx += 1 } @@ -550,6 +541,7 @@ 
private[spark] object RandomForest extends Logging { } } + // Aggregate sufficient stats by node, then find best splits val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map { case (nodeIndex, aggStats) => val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => @@ -558,12 +550,13 @@ private[spark] object RandomForest extends Logging { // find best split for each node val (split: Split, stats: ImpurityStats) = - binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) + RandomForest.binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) (nodeIndex, (split, stats)) }.collectAsMap() timer.stop("chooseSplits") + // Perform splits val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { Array.fill[mutable.Map[Int, NodeIndexUpdater]]( metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]()) @@ -627,221 +620,38 @@ private[spark] object RandomForest extends Logging { } /** - * Calculate the impurity statistics for a given (feature, split) based upon left/right - * aggregates. - * - * @param stats the recycle impurity statistics for this feature's all splits, - * only 'impurity' and 'impurityCalculator' are valid between each iteration - * @param leftImpurityCalculator left node aggregates for this (feature, split) - * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @param metadata learning and dataset metadata for DecisionTree - * @return Impurity statistics for this (feature, split) + * Return a list of pairs (featureIndexIdx, featureIndex) where featureIndex is the global + * (across all trees) index of a feature and featureIndexIdx is the index of a feature within the + * list of features for a given node. 
Filters out features known to be constant + * (features with 0 splits) */ - private def calculateImpurityStats( - stats: ImpurityStats, - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata): ImpurityStats = { - - val parentImpurityCalculator: ImpurityCalculator = if (stats == null) { - leftImpurityCalculator.copy.add(rightImpurityCalculator) - } else { - stats.impurityCalculator - } - - val impurity: Double = if (stats == null) { - parentImpurityCalculator.calculate() - } else { - stats.impurity - } - - val leftCount = leftImpurityCalculator.count - val rightCount = rightImpurityCalculator.count - - val totalCount = leftCount + rightCount - - // If left child or right child doesn't satisfy minimum instances per node, - // then this split is invalid, return invalid information gain stats. - if ((leftCount < metadata.minInstancesPerNode) || - (rightCount < metadata.minInstancesPerNode)) { - return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) - } - - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 - val rightImpurity = rightImpurityCalculator.calculate() - - val leftWeight = leftCount / totalCount.toDouble - val rightWeight = rightCount / totalCount.toDouble - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - // if information gain doesn't satisfy minimum information gain, - // then this split is invalid, return invalid information gain stats. 
- if (gain < metadata.minInfoGain) { - return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + private[impl] def getFeaturesWithSplits( + metadata: DecisionTreeMetadata, + featuresForNode: Option[Array[Int]]): SeqView[(Int, Int), Seq[_]] = { + Range(0, metadata.numFeaturesPerNode).view.map { featureIndexIdx => + featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) + .getOrElse((featureIndexIdx, featureIndexIdx)) + }.withFilter { case (_, featureIndex) => + metadata.numSplits(featureIndex) != 0 } - - new ImpurityStats(gain, impurity, parentImpurityCalculator, - leftImpurityCalculator, rightImpurityCalculator) } - /** - * Find the best split for a node. - * - * @param binAggregates Bin statistics. - * @return tuple for best split: (Split, information gain, prediction at node) - */ - private[tree] def binsToBestSplit( - binAggregates: DTStatsAggregator, - splits: Array[Array[Split]], + private[impl] def getBestSplitByGain( + parentImpurityCalculator: ImpurityCalculator, + metadata: DecisionTreeMetadata, featuresForNode: Option[Array[Int]], - node: LearningNode): (Split, ImpurityStats) = { - - // Calculate InformationGain and ImpurityStats if current node is top node - val level = LearningNode.indexToLevel(node.id) - var gainAndImpurityStats: ImpurityStats = if (level == 0) { - null - } else { - node.stats - } - - val validFeatureSplits = - Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx => - featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) - .getOrElse((featureIndexIdx, featureIndexIdx)) - }.withFilter { case (_, featureIndex) => - binAggregates.metadata.numSplits(featureIndex) != 0 - } - - // For each (feature, split), calculate the gain, and select the best (feature, split). 
- val splitsAndImpurityInfo = - validFeatureSplits.map { case (featureIndexIdx, featureIndex) => - val numSplits = binAggregates.metadata.numSplits(featureIndex) - if (binAggregates.metadata.isContinuous(featureIndex)) { - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - var splitIndex = 0 - while (splitIndex < numSplits) { - binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) - splitIndex += 1 - } - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { case splitIdx => - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIdx, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else if (binAggregates.metadata.isUnordered(featureIndex)) { - // Unordered categorical feature - val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates.getParentImpurityCalculator() - .subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else { - // Ordered categorical feature - val nodeFeatureOffset = 
binAggregates.getFeatureOffset(featureIndexIdx) - val numCategories = binAggregates.metadata.numBins(featureIndex) - - /* Each bin is one category (feature value). - * The bins are ordered based on centroidForCategories, and this ordering determines which - * splits are considered. (With K categories, we consider K - 1 possible splits.) - * - * centroidForCategories is a list: (category, centroid) - */ - val centroidForCategories = Range(0, numCategories).map { case featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - if (binAggregates.metadata.isMulticlass) { - // multiclass classification - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - categoryStats.calculate() - } else if (binAggregates.metadata.isClassification) { - // binary classification - // For categorical variables in binary classification, - // the bins are ordered by the count of class 1. - categoryStats.stats(1) - } else { - // regression - // For categorical variables in regression and binary classification, - // the bins are ordered by the prediction. - categoryStats.predict - } - } else { - Double.MaxValue - } - (featureValue, centroid) - } - - logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) - - // bins sorted by centroids - val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) - - logDebug("Sorted centroids for categorical variable = " + - categoriesSortedByCentroid.mkString(",")) - - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. 
- var splitIndex = 0 - while (splitIndex < numSplits) { - val currentCategory = categoriesSortedByCentroid(splitIndex)._1 - val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 - binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) - splitIndex += 1 - } - // lastCategory = index of bin with total aggregates for this (node, feature) - val lastCategory = categoriesSortedByCentroid.last._1 - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - val categoriesForSplit = - categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) - val bestFeatureSplit = - new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) - (bestFeatureSplit, bestFeatureGainStats) - } - } - + splitsAndImpurityInfo: Seq[(Split, ImpurityStats)]): (Split, ImpurityStats) = { val (bestSplit, bestSplitStats) = if (splitsAndImpurityInfo.isEmpty) { // If no valid splits for features, then this split is invalid, // return invalid information gain stats. Take any split and continue. 
// Splits is empty, so arbitrarily choose to split on any threshold val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0) - val parentImpurityCalculator = binAggregates.getParentImpurityCalculator() - if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) { + if (metadata.isContinuous(dummyFeatureIndex)) { (new ContinuousSplit(dummyFeatureIndex, 0), ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) } else { - val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex) + val numCategories = metadata.featureArity(dummyFeatureIndex) (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories), ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) } @@ -851,6 +661,41 @@ private[spark] object RandomForest extends Logging { (bestSplit, bestSplitStats) } + /** + * Find the best split for a node. + * + * @param binAggregates Bin statistics. + * @return tuple for best split: (Split, ImpurityStats for that split) + */ + private[tree] def binsToBestSplit( + binAggregates: DTStatsAggregator, + splits: Array[Array[Split]], + featuresForNode: Option[Array[Int]], + node: LearningNode): (Split, ImpurityStats) = { + val validFeatureSplits = getFeaturesWithSplits(binAggregates.metadata, featuresForNode) + // For each (feature, split), calculate the gain, and select the best (feature, split). 
+ val parentImpurityCalc = if (node.stats == null) None else Some(node.stats.impurityCalculator) + val splitsAndImpurityInfo = + validFeatureSplits.map { case (featureIndexIdx, featureIndex) => + SplitUtils.chooseSplit(binAggregates, featureIndex, featureIndexIdx, splits(featureIndex), + parentImpurityCalc) + } + getBestSplitByGain(binAggregates.getParentImpurityCalculator(), binAggregates.metadata, + featuresForNode, splitsAndImpurityInfo) + } + + private[impl] def findUnorderedSplits( + metadata: DecisionTreeMetadata, + featureIndex: Int): Array[Split] = { + // Unordered features + // 2^(maxFeatureValue - 1) - 1 combinations + val featureArity = metadata.featureArity(featureIndex) + Array.tabulate[Split](metadata.numSplits(featureIndex)) { splitIndex => + val categories = extractMultiClassCategories(splitIndex + 1, featureArity) + new CategoricalSplit(featureIndex, categories.toArray, featureArity) + } + } + /** * Returns splits for decision tree calculation. * Continuous and categorical features are handled differently. 
@@ -936,13 +781,7 @@ private[spark] object RandomForest extends Logging { split case i if metadata.isCategorical(i) && metadata.isUnordered(i) => - // Unordered features - // 2^(maxFeatureValue - 1) - 1 combinations - val featureArity = metadata.featureArity(i) - Array.tabulate[Split](metadata.numSplits(i)) { splitIndex => - val categories = extractMultiClassCategories(splitIndex + 1, featureArity) - new CategoricalSplit(i, categories.toArray, featureArity) - } + findUnorderedSplits(metadata, i) case i if metadata.isCategorical(i) => // Ordered features @@ -1147,4 +986,5 @@ private[spark] object RandomForest extends Logging { 3 * totalBins } } + } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala new file mode 100644 index 000000000000..206405a69305 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.{CategoricalSplit, Split} +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.mllib.tree.model.ImpurityStats + +/** Utility methods for choosing splits during local & distributed tree training. */ +private[impl] object SplitUtils extends Logging { + + /** Sorts ordered feature categories by label centroid, returning an ordered list of categories */ + private def sortByCentroid( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int): List[Int] = { + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines which + * splits are considered. (With K categories, we consider K - 1 possible splits.) + * + * centroidForCategories is a list: (category, centroid) + */ + val numCategories = binAggregates.metadata.numBins(featureIndex) + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + + val centroidForCategories = Range(0, numCategories).map { featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = ImpurityUtils.getCentroid(binAggregates.metadata, categoryStats) + (featureValue, centroid) + } + logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2).map(_._1) + logDebug("Sorted centroids for categorical variable = " + + categoriesSortedByCentroid.mkString(",")) + categoriesSortedByCentroid + } + + /** + * Find the best split for an unordered categorical feature at a single node. + * + * Algorithm: + * - Considers all possible subsets (exponentially many) + * + * @param featureIndex Global index of feature being split. 
+ * @param featureIndexIdx Index of feature being split within subset of features for current node. + * @param featureSplits Array of splits for the current feature + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + * @return (best split, statistics for split) If no valid split was found, the returned + * ImpurityStats instance will be invalid (have member valid = false). + */ + private[impl] def chooseUnorderedCategoricalSplit( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + // Unordered categorical feature + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val numSplits = binAggregates.metadata.numSplits(featureIndex) + val parentCalc = parentCalculator.getOrElse(binAggregates.getParentImpurityCalculator()) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex) + val rightChildStats = binAggregates.getParentImpurityCalculator() + .subtract(leftChildStats) + val gainAndImpurityStats = ImpurityUtils.calculateImpurityStats(parentCalc, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) + (featureSplits(bestFeatureSplitIndex), bestFeatureGainStats) + + } + + /** + * Choose splitting rule: feature value <= threshold + * + * @return (best split, statistics for split) The split is always present (never None); + * whether it should be performed is indicated by the stats. 
If no valid split was found, the + * returned ImpurityStats instance will be invalid (have member valid = false) + */ + private[impl] def chooseContinuousSplit( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + // For a continuous feature, bins are already sorted for splitting + // Number of "categories" = number of bins; kept as a Range for constant-time indexed access + val sortedCategories = Range(0, binAggregates.metadata.numBins(featureIndex)) + // Get & return best split info + val (bestFeatureSplitIndex, bestFeatureGainStats) = orderedSplitHelper(binAggregates, + featureIndex, featureIndexIdx, sortedCategories, parentCalculator) + (featureSplits(bestFeatureSplitIndex), bestFeatureGainStats) + } + + /** + * Computes the index of the best split for an ordered feature. + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + */ + private def orderedSplitHelper( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + categoriesSortedByCentroid: IndexedSeq[Int], + parentCalculator: Option[ImpurityCalculator]): (Int, ImpurityStats) = { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val numSplits = binAggregates.metadata.numSplits(featureIndex) + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex) + val nextCategory = categoriesSortedByCentroid(splitIndex + 1) + binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) + splitIndex += 1 + } + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last + + // Find best split. 
+ val parentCalc = parentCalculator.getOrElse(binAggregates.getParentImpurityCalculator()) + Range(0, numSplits).map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex) + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + val gainAndImpurityStats = ImpurityUtils.calculateImpurityStats(parentCalc, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) + } + + /** + * Choose the best split for an ordered categorical feature. + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + */ + private[impl] def chooseOrderedCategoricalSplit( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + // Sort feature categories by label centroid + val categoriesSortedByCentroid = sortByCentroid(binAggregates, featureIndex, featureIndexIdx) + // Get index, stats of best split + val (bestFeatureSplitIndex, bestFeatureGainStats) = orderedSplitHelper(binAggregates, + featureIndex, featureIndexIdx, categoriesSortedByCentroid.toIndexedSeq, parentCalculator) + // Create result (CategoricalSplit instance) + val categoriesForSplit = + categoriesSortedByCentroid.map(_.toDouble).slice(0, bestFeatureSplitIndex + 1) + val numCategories = binAggregates.metadata.featureArity(featureIndex) + val bestFeatureSplit = + new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) + (bestFeatureSplit, bestFeatureGainStats) + } + + /** + * Choose the best split for a feature at a node. 
+ * + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + * @return (best split, statistics for split) If no valid split was found, the returned + * ImpurityStats instance will be invalid (have member valid = false). + */ + private[impl] def chooseSplit( + statsAggregator: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + val metadata = statsAggregator.metadata + if (metadata.isCategorical(featureIndex)) { + if (metadata.isUnordered(featureIndex)) { + SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, + featureIndexIdx, featureSplits, parentCalculator) + } else { + SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, + featureIndexIdx, parentCalculator) + } + } else { + SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, featureIndexIdx, + featureSplits, parentCalculator) + } + + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index f3dbfd96e181..029a709f553d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -75,8 +75,9 @@ class InformationGainStats( * @param impurityCalculator impurity statistics for current node * @param leftImpurityCalculator impurity statistics for left child node * @param rightImpurityCalculator impurity statistics for right child node - * @param valid whether the current split satisfies minimum info gain or - * minimum number of instances per node + * @param valid whether the current split should be performed; true if split + * satisfies minimum info gain, minimum number of instances per node, and + * has positive info gain.
*/ private[spark] class ImpurityStats( val gain: Double, @@ -112,7 +113,7 @@ private[spark] object ImpurityStats { * minimum number of instances per node. */ def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { - new ImpurityStats(Double.MinValue, impurityCalculator.calculate(), + new ImpurityStats(Double.MinValue, impurity = -1, impurityCalculator, null, null, false) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala new file mode 100644 index 000000000000..290f6d78ef67 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, Split} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.tree.impurity.{Entropy, Impurity} +import org.apache.spark.mllib.tree.model.ImpurityStats +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** Suite exercising helper methods for making split decisions during decision tree training. */ +class TreeSplitUtilsSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + /** + * Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated + * with the training points (values(i), labels(i)). Assumes a feature index of 0 and that + * all training points have the same weights (1.0). + */ + private def getAggregator( + metadata: DecisionTreeMetadata, + values: Array[Int], + labels: Array[Double], + featureSplits: Array[Split]): DTStatsAggregator = { + // Create stats aggregator + val statsAggregator = new DTStatsAggregator(metadata, featureSubset = None) + // Update parent impurity stats + val featureIndex = 0 + val instanceWeights = Array.fill[Double](values.length)(1.0) + AggUpdateUtils.updateParentImpurity(statsAggregator, indices = values.indices.toArray, + from = 0, to = values.length, instanceWeights, labels) + // Update per-feature-value stats; zip pairs values with labels up to the shorter array + values.zip(labels).foreach { case (value: Int, label: Double) => + if (metadata.isUnordered(featureIndex)) { + AggUpdateUtils.updateUnorderedFeature(statsAggregator, value, label, + featureIndex = featureIndex, featureIndexIdx = 0, featureSplits, instanceWeight = 1.0) + } else { + AggUpdateUtils.updateOrderedFeature(statsAggregator, value, label, featureIndexIdx = 0, + instanceWeight = 1.0) + } + } + statsAggregator + } + + /** + * Check that left/right impurities match what we'd expect for a split.
+ * @param labels Labels whose impurity information should be reflected in stats + * @param stats ImpurityStats object containing impurity info for the left/right sides of a split + */ + private def validateImpurityStats( + impurity: Impurity, + labels: Array[Double], + stats: ImpurityStats, + expectedLeftStats: Array[Double], + expectedRightStats: Array[Double]): Unit = { + // Compute impurity for our data points manually, using the caller-supplied impurity measure + val numClasses = (labels.max + 1).toInt + val fullImpurityStatsArray + = Array.tabulate[Double](numClasses)((label: Int) => labels.count(_ == label).toDouble) + val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length) + // Verify that impurity stats were computed correctly for split + assert(stats.impurityCalculator.stats === fullImpurityStatsArray) + assert(stats.impurity === fullImpurity) + assert(stats.leftImpurityCalculator.stats === expectedLeftStats) + assert(stats.rightImpurityCalculator.stats === expectedRightStats) + assert(stats.valid) + } + + /* * * * * * * * * * * Choosing Splits * * * * * * * * * * */ + + test("chooseSplit: choose correct type of split (continuous split)") { + // Construct (binned) continuous data + val labels = Array(0.0, 0.0, 1.0) + val values = Array(1, 2, 3) + val featureIndex = 0 + // Get an array of continuous splits corresponding to values in our binned data + val splits = TreeTests.getContinuousSplits(thresholds = values.distinct.sorted, + featureIndex = 0) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map.empty) + val statsAggregator = getAggregator(metadata, values, labels, splits) + // Choose split, check that it's a valid ContinuousSplit + val (split, stats) = SplitUtils.chooseSplit(statsAggregator, featureIndex, featureIndex, + splits) + assert(stats.valid && split.isInstanceOf[ContinuousSplit]) + } + + test("chooseSplit: choose correct type of split (categorical
split)") { + // Construct categorical data + val labels = Array(0.0, 0.0, 1.0, 1.0, 1.0) + val featureArity = 3 + val values = Array(0, 0, 1, 2, 2) + val featureIndex = 0 + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map(featureIndex -> featureArity)) + val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) + val statsAggregator = getAggregator(metadata, values, labels, splits) + // Choose split, check that it's a valid categorical split + val (split, stats) = SplitUtils.chooseSplit(statsAggregator = statsAggregator, + featureIndex = featureIndex, featureIndexIdx = featureIndex, featureSplits = splits) + assert(stats.valid && split.isInstanceOf[CategoricalSplit]) + } + + test("chooseOrderedCategoricalSplit: basic case") { + // Helper method for testing ordered categorical split + def testHelper( + values: Array[Int], + labels: Array[Double], + expectedLeftCategories: Array[Double], + expectedLeftStats: Array[Double], + expectedRightStats: Array[Double]): Unit = { + // Set up metadata for ordered categorical feature + val featureIndex = 0 + val featureArity = values.max + 1 + val arityMap = Map[Int, Int](featureIndex -> featureArity) + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, arityMap, unorderedFeatures = Some(Set.empty)) + // Construct DTStatsAggregator, compute sufficient stats + val statsAggregator = getAggregator(metadata, values, labels, featureSplits = Array.empty) + // Choose split + val (split, stats) = + SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, featureIndex) + // Verify that split has the expected left-side/right-side categories + val expectedRightCategories = Range(0, featureArity) + .filter(c => !expectedLeftCategories.contains(c)).map(_.toDouble).toArray + split match { + case s: CategoricalSplit => + assert(s.featureIndex === 
featureIndex) + assert(s.leftCategories === expectedLeftCategories) + assert(s.rightCategories === expectedRightCategories) + case _ => + throw new AssertionError( + s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}") + } + validateImpurityStats(Entropy, labels, stats, expectedLeftStats, expectedRightStats) + } + + // Test a single split: The left side of our split should contain the two points with label 0, + // the right side of our split should contain the five points with label 1 + val values = Array(0, 0, 1, 2, 2, 2, 2) + val labels1 = Array(0, 0, 1, 1, 1, 1, 1).map(_.toDouble) + testHelper(values, labels1, expectedLeftCategories = Array(0.0), + expectedLeftStats = Array(2.0, 0.0), expectedRightStats = Array(0.0, 5.0)) + + // Test a single split: The left side of our split should contain the three points with label 0, + // the right side of our split should contain the four points with label 1 + val labels2 = Array(0, 0, 0, 1, 1, 1, 1).map(_.toDouble) + testHelper(values, labels2, expectedLeftCategories = Array(0.0, 1.0), + expectedLeftStats = Array(3.0, 0.0), expectedRightStats = Array(0.0, 4.0)) + } + + test("chooseOrderedCategoricalSplit: return bad stats if we should not split") { + // Construct categorical data + val featureIndex = 0 + val values = Array(0, 0, 1, 2, 2, 2, 2) + val featureArity = values.max + 1 + val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map(featureIndex -> featureArity), unorderedFeatures = Some(Set.empty)) + val statsAggregator = getAggregator(metadata, values, labels, featureSplits = Array.empty) + // Choose split, verify that it's invalid + val (_, stats) = SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, + featureIndex) + assert(!stats.valid) + } + + test("chooseUnorderedCategoricalSplit: basic case") { + val featureIndex = 0
+ // Construct data for unordered categorical feature + // label: 0 --> values: 0, 1 + // label: 1 --> values: 2, 3 + // label: 2 --> values: 2, 2, 4 + // Expected split: feature values (0, 1) on the left, values (2, 3, 4) on the right + val values = Array(0, 1, 2, 3, 2, 2, 4) + val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0) + val featureArity = values.max + 1 + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 3, Map(featureIndex -> featureArity)) + val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) + val statsAggregator = getAggregator(metadata, values, labels, splits) + // Choose split + val (split, stats) = + SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, featureIndex, + splits) + // Verify that split has the expected left-side/right-side categories + split match { + case s: CategoricalSplit => + assert(s.featureIndex === featureIndex) + assert(s.leftCategories.toSet === Set(0.0, 1.0)) + assert(s.rightCategories.toSet === Set(2.0, 3.0, 4.0)) + case _ => + throw new AssertionError( + s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}") + } + validateImpurityStats(Entropy, labels, stats, expectedLeftStats = Array(2.0, 0.0, 0.0), + expectedRightStats = Array(0.0, 2.0, 3.0)) + } + + test("chooseUnorderedCategoricalSplit: return bad stats if we should not split") { + // Construct data for unordered categorical feature; all points have label 1 + val featureIndex = 0 + val featureArity = 4 + val values = Array(3, 1, 0, 2, 2) + val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map(featureIndex -> featureArity)) + val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) + val statsAggregator = getAggregator(metadata, values, 
labels, splits) + // Choose split, verify that it's invalid + val (_, stats) = SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, + featureIndex, splits) + assert(!stats.valid) + } + + test("chooseContinuousSplit: basic case") { + // Construct data for continuous feature + val featureIndex = 0 + val thresholds = Array(0, 1, 2, 3) + val values = thresholds.indices.toArray + val labels = Array(0.0, 0.0, 1.0, 1.0) + // Construct DTStatsAggregator, compute sufficient stats + val splits = TreeTests.getContinuousSplits(thresholds, featureIndex) + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map.empty) + val statsAggregator = getAggregator(metadata, values, labels, splits) + + // Choose split, verify that it has expected threshold + val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, + featureIndex, splits) + split match { + case s: ContinuousSplit => + assert(s.featureIndex === featureIndex) + assert(s.threshold === 1) + case _ => + throw new AssertionError( + s"Expected ContinuousSplit but got ${split.getClass.getSimpleName}") + } + // Verify impurity stats of split + validateImpurityStats(Entropy, labels, stats, expectedLeftStats = Array(2.0, 0.0), + expectedRightStats = Array(0.0, 2.0)) + } + + test("chooseContinuousSplit: return bad stats if we should not split") { + // Construct data for continuous feature; one label per value, all points have label 0 + val featureIndex = 0 + val thresholds = Array(0, 1, 2, 3) + val values = thresholds.indices.toArray + val labels = Array.fill(values.length)(0.0) + // Construct DTStatsAggregator, compute sufficient stats + val splits = TreeTests.getContinuousSplits(thresholds, featureIndex) + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map.empty[Int, Int]) + val statsAggregator = getAggregator(metadata, values, labels, splits) + // Choose split, verify that it's invalid + val (_, stats)
= SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, + featureIndex, splits) + assert(!stats.valid) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b6894b30b0c2..d16ca5432cf0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -25,6 +25,7 @@ import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericA import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Impurity} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession} @@ -101,6 +102,48 @@ private[ml] object TreeTests extends SparkFunSuite { data.select(data(featuresColName), data(labelColName).as(labelColName, labelMetadata)) } + /** Returns a DecisionTreeMetadata instance with hard-coded values for use in tests */ + def getMetadata( + numExamples: Int, + numFeatures: Int, + numClasses: Int, + featureArity: Map[Int, Int], + impurity: Impurity = Entropy, + unorderedFeatures: Option[Set[Int]] = None): DecisionTreeMetadata = { + // By default, assume all categorical features within tests + // have small enough arity to be treated as unordered + val unordered = unorderedFeatures.getOrElse(featureArity.keys.toSet) + + // Set numBins appropriately for categorical features + val maxBins = 4 + val numBins: Array[Int] = 0.until(numFeatures).toArray.map { featureIndex => + if (featureArity.contains(featureIndex) && featureArity(featureIndex) > 0) { + featureArity(featureIndex) + } else { + maxBins + } + } + + new DecisionTreeMetadata(numFeatures = numFeatures, numExamples = numExamples, + numClasses = numClasses, maxBins = maxBins, minInfoGain = 0.0, featureArity = featureArity, + unorderedFeatures = unordered, numBins = 
numBins, impurity = impurity, + quantileStrategy = null, maxDepth = 5, minInstancesPerNode = 1, numTrees = 1, + numFeaturesPerNode = 2) + } + + /** + * Returns an array of continuous splits for the feature with index featureIndex and the passed-in + * set of thresholds. Creates one continuous split per threshold, in ascending threshold order. + */ + private[impl] def getContinuousSplits( + thresholds: Array[Int], + featureIndex: Int): Array[Split] = { + val splits = thresholds.sorted.map { + new ContinuousSplit(featureIndex, _).asInstanceOf[Split] + } + splits + } + /** * Check if the two trees are exactly the same. * Note: I hesitate to override Node.equals since it could cause problems if users