From f2e3fbd40eea2919d249710eae5b5789d97543b7 Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Wed, 15 Nov 2017 09:52:01 -0800 Subject: [PATCH 1/8] Local tree training part 1 (refactor RandomForest.scala into utility classes) --- .../scala/org/apache/spark/ml/tree/Node.scala | 14 +- .../spark/ml/tree/impl/AggUpdateUtils.scala | 85 +++++ .../spark/ml/tree/impl/ImpurityUtils.scala | 135 ++++++++ .../spark/ml/tree/impl/RandomForest.scala | 292 ++++-------------- .../spark/ml/tree/impl/SplitUtils.scala | 206 ++++++++++++ .../tree/model/InformationGainStats.scala | 7 +- 6 files changed, 501 insertions(+), 238 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 07e98a142b10..17aba54f21bb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -276,14 +276,10 @@ private[tree] class LearningNode( new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) } else { - if (stats.valid) { - new LeafNode(stats.impurityCalculator.predict, stats.impurity, - stats.impurityCalculator) - } else { - // Here we want to keep same behavior with the old mllib.DecisionTreeModel - new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) - } - + assert(stats != null, "Unknown error during Decision Tree learning. Could not convert " + + "LearningNode to Node") + new LeafNode(stats.impurityCalculator.predict, stats.impurity, + stats.impurityCalculator) } } @@ -334,7 +330,7 @@ private[tree] object LearningNode { id: Int, isLeaf: Boolean, stats: ImpurityStats): LearningNode = { - new LearningNode(id, None, None, None, false, stats) + new LearningNode(id, None, None, None, isLeaf, stats) } /** Create an empty node with the given node index. Values must be set later on. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala new file mode 100644 index 000000000000..07e4a16e2990 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.ml.tree.Split + +/** + * Helpers for updating DTStatsAggregators during collection of sufficient stats for tree training. + */ +private[impl] object AggUpdateUtils { + + /** + * Updates the parent node stats of the passed-in impurity aggregator with the labels + * corresponding to the feature values at indices [from, to). + * @param indices Array of row indices for feature values; indices(i) = row index of the ith + * feature value + */ + private[impl] def updateParentImpurity( + statsAggregator: DTStatsAggregator, + indices: Array[Int], + from: Int, + to: Int, + instanceWeights: Array[Double], + labels: Array[Double]): Unit = { + from.until(to).foreach { idx => + val rowIndex = indices(idx) + val label = labels(rowIndex) + statsAggregator.updateParent(label, instanceWeights(rowIndex)) + } + } + + /** + * Update aggregator for an (unordered feature, label) pair + * @param featureSplits Array of splits for the current feature + */ + private[impl] def updateUnorderedFeature( + agg: DTStatsAggregator, + featureValue: Int, + label: Double, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + instanceWeight: Double): Unit = { + val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) + // Each unordered split has a corresponding bin for impurity stats of data points that fall + // onto the left side of the split. For each unordered split, update left-side bin if applicable + // for the current data point. + val numSplits = agg.metadata.numSplits(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplits) { + if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { + agg.featureUpdate(leftNodeFeatureOffset, splitIndex, label, instanceWeight) + } + splitIndex += 1 + } + } + + /** Update aggregator for an (ordered feature, label) pair */ + private[impl] def updateOrderedFeature( + agg: DTStatsAggregator, + featureValue: Int, + label: Double, + featureIndexIdx: Int, + instanceWeight: Double): Unit = { + // The bin index of an ordered feature is just the feature value itself + val binIndex = featureValue + agg.update(featureIndexIdx, binIndex, label, instanceWeight) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala new file mode 100644 index 000000000000..0dd021eec247 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.mllib.tree.impurity._ +import org.apache.spark.mllib.tree.model.ImpurityStats + +/** Helper methods for impurity-related calculations during node split decisions. */ +private[impl] object ImpurityUtils { + + /** + * Get impurity calculator containing statistics for all labels for rows corresponding to + * feature values in [from, to). + * @param indices indices(i) = row index corresponding to ith feature value + */ + private[impl] def getParentImpurityCalculator( + metadata: DecisionTreeMetadata, + indices: Array[Int], + from: Int, + to: Int, + instanceWeights: Array[Double], + labels: Array[Double]): ImpurityCalculator = { + // Compute sufficient stats (e.g. label counts) for all data at the current node, + // store result in currNodeStatsAgg.parentStats so that we can share it across + // all features for the current node + val currNodeStatsAgg = new DTStatsAggregator(metadata, featureSubset = None) + AggUpdateUtils.updateParentImpurity(currNodeStatsAgg, indices, from, to, + instanceWeights, labels) + currNodeStatsAgg.getParentImpurityCalculator() + } + + /** + * Calculate the impurity statistics for a given (feature, split) based upon left/right + * aggregates. + * + * @param parentImpurityCalculator An ImpurityCalculator containing the impurity stats + * of the node currently being split. + * @param leftImpurityCalculator left node aggregates for this (feature, split) + * @param rightImpurityCalculator right node aggregate for this (feature, split) + * @param metadata learning and dataset metadata for DecisionTree + * @return Impurity statistics for this (feature, split) + */ + private[impl] def calculateImpurityStats( + parentImpurityCalculator: ImpurityCalculator, + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator, + metadata: DecisionTreeMetadata): ImpurityStats = { + + val impurity: Double = parentImpurityCalculator.calculate() + + val leftCount = leftImpurityCalculator.count + val rightCount = rightImpurityCalculator.count + + val totalCount = leftCount + rightCount + + // If left child or right child doesn't satisfy minimum instances per node, + // then this split is invalid, return invalid information gain stats. + if ((leftCount < metadata.minInstancesPerNode) || + (rightCount < metadata.minInstancesPerNode)) { + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + } + + val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 + val rightImpurity = rightImpurityCalculator.calculate() + + val leftWeight = leftCount / totalCount.toDouble + val rightWeight = rightCount / totalCount.toDouble + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + // If information gain doesn't satisfy minimum information gain, + // then this split is invalid, return invalid information gain stats. + if (gain < metadata.minInfoGain) { + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + } + + // If information gain is non-positive but doesn't violate the minimum info gain constraint, + // return a stats object with correct values but valid = false to indicate that we should not + // split. 
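+    // (This branch is reachable only when minInfoGain <= 0; e.g. splitting a pure node with
+    // labels (1.0, 1.0) into children (1.0) and (1.0) yields gain == 0, growing the tree
+    // without reducing weighted impurity.)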
+ if (gain <= 0) { + return new ImpurityStats(gain, impurity, parentImpurityCalculator, leftImpurityCalculator, + rightImpurityCalculator, valid = false) + } + + + new ImpurityStats(gain, impurity, parentImpurityCalculator, + leftImpurityCalculator, rightImpurityCalculator) + } + + /** + * Given an impurity aggregator containing label statistics for a given (node, feature, bin), + * returns the corresponding "centroid", used to order bins while computing best splits. + * + * @param metadata learning and dataset metadata for DecisionTree + */ + private[impl] def getCentroid( + metadata: DecisionTreeMetadata, + binStats: ImpurityCalculator): Double = { + + if (binStats.count != 0) { + if (metadata.isMulticlass) { + // multiclass classification + // For categorical features in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + binStats.calculate() + } else if (metadata.isClassification) { + // binary classification + // For categorical features in binary classification, + // the bins are ordered by the count of class 1. + binStats.stats(1) + } else { + // regression + // For categorical features in regression and binary classification, + // the bins are ordered by the prediction. + binStats.predict + } + } else { + Double.MaxValue + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index acfc6399c553..f8c3dd7ff2e7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.tree.impl import java.io.IOException -import scala.collection.mutable +import scala.collection.{mutable, SeqView} import scala.util.Random import org.apache.spark.internal.Logging @@ -280,23 +280,14 @@ private[spark] object RandomForest extends Logging { featureIndexIdx } if (unorderedFeatures.contains(featureIndex)) { - // Unordered feature - val featureValue = treePoint.binnedFeatures(featureIndex) - val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) - // Update the left or right bin for each split. 
- val numSplits = agg.metadata.numSplits(featureIndex) - val featureSplits = splits(featureIndex) - var splitIndex = 0 - while (splitIndex < numSplits) { - if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { - agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) - } - splitIndex += 1 - } + AggUpdateUtils.updateUnorderedFeature(agg, + featureValue = treePoint.binnedFeatures(featureIndex), label = treePoint.label, + featureIndex = featureIndex, featureIndexIdx = featureIndexIdx, + featureSplits = splits(featureIndex), instanceWeight = instanceWeight) } else { - // Ordered feature - val binIndex = treePoint.binnedFeatures(featureIndex) - agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight) + AggUpdateUtils.updateOrderedFeature(agg, + featureValue = treePoint.binnedFeatures(featureIndex), label = treePoint.label, + featureIndexIdx = featureIndexIdx, instanceWeight = instanceWeight) } featureIndexIdx += 1 } @@ -550,6 +541,7 @@ private[spark] object RandomForest extends Logging { } } + // Aggregate sufficient stats by node, then find best splits val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map { case (nodeIndex, aggStats) => val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => @@ -558,12 +550,13 @@ private[spark] object RandomForest extends Logging { // find best split for each node val (split: Split, stats: ImpurityStats) = - binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) + RandomForest.binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) (nodeIndex, (split, stats)) }.collectAsMap() timer.stop("chooseSplits") + // Perform splits val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { Array.fill[mutable.Map[Int, NodeIndexUpdater]]( metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]()) @@ -627,221 +620,38 @@ private[spark] object RandomForest extends Logging { } /** - * Calculate the impurity statistics for a given (feature, split) based upon left/right - * aggregates. - * - * @param stats the recycle impurity statistics for this feature's all splits, - * only 'impurity' and 'impurityCalculator' are valid between each iteration - * @param leftImpurityCalculator left node aggregates for this (feature, split) - * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @param metadata learning and dataset metadata for DecisionTree - * @return Impurity statistics for this (feature, split) + * Return a list of pairs (featureIndexIdx, featureIndex) where featureIndex is the global + * (across all trees) index of a feature and featureIndexIdx is the index of a feature within the + * list of features for a given node. 
Filters out features known to be constant + * (features with 0 splits) */ - private def calculateImpurityStats( - stats: ImpurityStats, - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata): ImpurityStats = { - - val parentImpurityCalculator: ImpurityCalculator = if (stats == null) { - leftImpurityCalculator.copy.add(rightImpurityCalculator) - } else { - stats.impurityCalculator - } - - val impurity: Double = if (stats == null) { - parentImpurityCalculator.calculate() - } else { - stats.impurity - } - - val leftCount = leftImpurityCalculator.count - val rightCount = rightImpurityCalculator.count - - val totalCount = leftCount + rightCount - - // If left child or right child doesn't satisfy minimum instances per node, - // then this split is invalid, return invalid information gain stats. - if ((leftCount < metadata.minInstancesPerNode) || - (rightCount < metadata.minInstancesPerNode)) { - return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) - } - - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 - val rightImpurity = rightImpurityCalculator.calculate() - - val leftWeight = leftCount / totalCount.toDouble - val rightWeight = rightCount / totalCount.toDouble - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - // if information gain doesn't satisfy minimum information gain, - // then this split is invalid, return invalid information gain stats. - if (gain < metadata.minInfoGain) { - return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + private[impl] def getFeaturesWithSplits( + metadata: DecisionTreeMetadata, + featuresForNode: Option[Array[Int]]): SeqView[(Int, Int), Seq[_]] = { + Range(0, metadata.numFeaturesPerNode).view.map { featureIndexIdx => + featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) + .getOrElse((featureIndexIdx, featureIndexIdx)) + }.withFilter { case (_, featureIndex) => + metadata.numSplits(featureIndex) != 0 } - - new ImpurityStats(gain, impurity, parentImpurityCalculator, - leftImpurityCalculator, rightImpurityCalculator) } - /** - * Find the best split for a node. - * - * @param binAggregates Bin statistics. - * @return tuple for best split: (Split, information gain, prediction at node) - */ - private[tree] def binsToBestSplit( - binAggregates: DTStatsAggregator, - splits: Array[Array[Split]], + private[impl] def getBestSplitByGain( + parentImpurityCalculator: ImpurityCalculator, + metadata: DecisionTreeMetadata, featuresForNode: Option[Array[Int]], - node: LearningNode): (Split, ImpurityStats) = { - - // Calculate InformationGain and ImpurityStats if current node is top node - val level = LearningNode.indexToLevel(node.id) - var gainAndImpurityStats: ImpurityStats = if (level == 0) { - null - } else { - node.stats - } - - val validFeatureSplits = - Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx => - featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) - .getOrElse((featureIndexIdx, featureIndexIdx)) - }.withFilter { case (_, featureIndex) => - binAggregates.metadata.numSplits(featureIndex) != 0 - } - - // For each (feature, split), calculate the gain, and select the best (feature, split). 
- val splitsAndImpurityInfo = - validFeatureSplits.map { case (featureIndexIdx, featureIndex) => - val numSplits = binAggregates.metadata.numSplits(featureIndex) - if (binAggregates.metadata.isContinuous(featureIndex)) { - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - var splitIndex = 0 - while (splitIndex < numSplits) { - binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) - splitIndex += 1 - } - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { case splitIdx => - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIdx, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else if (binAggregates.metadata.isUnordered(featureIndex)) { - // Unordered categorical feature - val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates.getParentImpurityCalculator() - .subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else { - // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val numCategories = binAggregates.metadata.numBins(featureIndex) - - /* Each bin is one category (feature value). - * The bins are ordered based on centroidForCategories, and this ordering determines which - * splits are considered. (With K categories, we consider K - 1 possible splits.) - * - * centroidForCategories is a list: (category, centroid) - */ - val centroidForCategories = Range(0, numCategories).map { case featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - if (binAggregates.metadata.isMulticlass) { - // multiclass classification - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - categoryStats.calculate() - } else if (binAggregates.metadata.isClassification) { - // binary classification - // For categorical variables in binary classification, - // the bins are ordered by the count of class 1. - categoryStats.stats(1) - } else { - // regression - // For categorical variables in regression and binary classification, - // the bins are ordered by the prediction. 
- categoryStats.predict - } - } else { - Double.MaxValue - } - (featureValue, centroid) - } - - logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) - - // bins sorted by centroids - val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) - - logDebug("Sorted centroids for categorical variable = " + - categoriesSortedByCentroid.mkString(",")) - - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - var splitIndex = 0 - while (splitIndex < numSplits) { - val currentCategory = categoriesSortedByCentroid(splitIndex)._1 - val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 - binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) - splitIndex += 1 - } - // lastCategory = index of bin with total aggregates for this (node, feature) - val lastCategory = categoriesSortedByCentroid.last._1 - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - val categoriesForSplit = - categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) - val bestFeatureSplit = - new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) - (bestFeatureSplit, bestFeatureGainStats) - } - } - + splitsAndImpurityInfo: Seq[(Split, ImpurityStats)]): (Split, ImpurityStats) = { val (bestSplit, bestSplitStats) = if (splitsAndImpurityInfo.isEmpty) { // If no valid splits for features, then this split is invalid, // return invalid information gain stats. Take any split and continue. // Splits is empty, so arbitrarily choose to split on any threshold val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0) - val parentImpurityCalculator = binAggregates.getParentImpurityCalculator() - if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) { + if (metadata.isContinuous(dummyFeatureIndex)) { (new ContinuousSplit(dummyFeatureIndex, 0), ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) } else { - val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex) + val numCategories = metadata.featureArity(dummyFeatureIndex) (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories), ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) } @@ -851,6 +661,41 @@ private[spark] object RandomForest extends Logging { (bestSplit, bestSplitStats) } + /** + * Find the best split for a node. + * + * @param binAggregates Bin statistics. + * @return tuple for best split: (Split, information gain, prediction at node) + */ + private[tree] def binsToBestSplit( + binAggregates: DTStatsAggregator, + splits: Array[Array[Split]], + featuresForNode: Option[Array[Int]], + node: LearningNode): (Split, ImpurityStats) = { + val validFeatureSplits = getFeaturesWithSplits(binAggregates.metadata, featuresForNode) + // For each (feature, split), calculate the gain, and select the best (feature, split). 
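+    // Reuse the impurity stats cached on non-root nodes as the parent stats for each candidate
+    // split; for the root node (node.stats == null), SplitUtils falls back to the parent stats
+    // stored in the aggregator.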
+ val parentImpurityCalc = if (node.stats == null) None else Some(node.stats.impurityCalculator) + val splitsAndImpurityInfo = + validFeatureSplits.map { case (featureIndexIdx, featureIndex) => + SplitUtils.chooseSplit(binAggregates, featureIndex, featureIndexIdx, splits(featureIndex), + parentImpurityCalc) + } + getBestSplitByGain(binAggregates.getParentImpurityCalculator(), binAggregates.metadata, + featuresForNode, splitsAndImpurityInfo) + } + + private[impl] def findUnorderedSplits( + metadata: DecisionTreeMetadata, + featureIndex: Int): Array[Split] = { + // Unordered features + // 2^(maxFeatureValue - 1) - 1 combinations + val featureArity = metadata.featureArity(featureIndex) + Array.tabulate[Split](metadata.numSplits(featureIndex)) { splitIndex => + val categories = extractMultiClassCategories(splitIndex + 1, featureArity) + new CategoricalSplit(featureIndex, categories.toArray, featureArity) + } + } + /** * Returns splits for decision tree calculation. * Continuous and categorical features are handled differently. @@ -936,13 +781,7 @@ private[spark] object RandomForest extends Logging { split case i if metadata.isCategorical(i) && metadata.isUnordered(i) => - // Unordered features - // 2^(maxFeatureValue - 1) - 1 combinations - val featureArity = metadata.featureArity(i) - Array.tabulate[Split](metadata.numSplits(i)) { splitIndex => - val categories = extractMultiClassCategories(splitIndex + 1, featureArity) - new CategoricalSplit(i, categories.toArray, featureArity) - } + findUnorderedSplits(metadata, i) case i if metadata.isCategorical(i) => // Ordered features @@ -1147,4 +986,5 @@ private[spark] object RandomForest extends Logging { 3 * totalBins } } + } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala new file mode 100644 index 000000000000..206405a69305 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.{CategoricalSplit, Split} +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.mllib.tree.model.ImpurityStats + +/** Utility methods for choosing splits during local & distributed tree training. */ +private[impl] object SplitUtils extends Logging { + + /** Sorts ordered feature categories by label centroid, returning an ordered list of categories */ + private def sortByCentroid( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int): List[Int] = { + /* Each bin is one category (feature value). 
+     * The bins are ordered based on centroidForCategories, and this ordering determines which
+     * splits are considered. (With K categories, we consider K - 1 possible splits.)
+     *
+     * centroidForCategories is a list: (category, centroid)
+     */
+    val numCategories = binAggregates.metadata.numBins(featureIndex)
+    val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+
+    val centroidForCategories = Range(0, numCategories).map { featureValue =>
+      val categoryStats =
+        binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+      val centroid = ImpurityUtils.getCentroid(binAggregates.metadata, categoryStats)
+      (featureValue, centroid)
+    }
+    logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
+    // bins sorted by centroids
+    val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2).map(_._1)
+    logDebug("Sorted centroids for categorical variable = " +
+      categoriesSortedByCentroid.mkString(","))
+    categoriesSortedByCentroid
+  }
+
+  /**
+   * Find the best split for an unordered categorical feature at a single node.
+   *
+   * Algorithm:
+   * - Considers all possible subsets (exponentially many)
+   *
+   * @param featureIndex Global index of feature being split.
+   * @param featureIndexIdx Index of feature being split within subset of features for current node.
+   * @param featureSplits Array of splits for the current feature
+   * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node
+   * @return (best split, statistics for split) If no valid split was found, the returned
+   *         ImpurityStats instance will be invalid (have member valid = false).
+   */
+  private[impl] def chooseUnorderedCategoricalSplit(
+      binAggregates: DTStatsAggregator,
+      featureIndex: Int,
+      featureIndexIdx: Int,
+      featureSplits: Array[Split],
+      parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = {
+    // Unordered categorical feature
+    val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+    val numSplits = binAggregates.metadata.numSplits(featureIndex)
+    val parentCalc = parentCalculator.getOrElse(binAggregates.getParentImpurityCalculator())
+    val (bestFeatureSplitIndex, bestFeatureGainStats) =
+      Range(0, numSplits).map { splitIndex =>
+        val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex)
+        val rightChildStats = binAggregates.getParentImpurityCalculator()
+          .subtract(leftChildStats)
+        val gainAndImpurityStats = ImpurityUtils.calculateImpurityStats(parentCalc,
+          leftChildStats, rightChildStats, binAggregates.metadata)
+        (splitIndex, gainAndImpurityStats)
+      }.maxBy(_._2.gain)
+    (featureSplits(bestFeatureSplitIndex), bestFeatureGainStats)
+  }
+
+  /**
+   * Choose splitting rule: feature value <= threshold
+   *
+   * @return (best split, statistics for split) 
If no valid split was found, the + * returned ImpurityStats instance will be invalid (have member valid = false) + */ + private[impl] def chooseContinuousSplit( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + // For a continuous feature, bins are already sorted for splitting + // Number of "categories" = number of bins + val sortedCategories = Range(0, binAggregates.metadata.numBins(featureIndex)).toList + // Get & return best split info + val (bestFeatureSplitIndex, bestFeatureGainStats) = orderedSplitHelper(binAggregates, + featureIndex, featureIndexIdx, sortedCategories, parentCalculator) + (featureSplits(bestFeatureSplitIndex), bestFeatureGainStats) + } + + /** + * Computes the index of the best split for an ordered feature. + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + */ + private def orderedSplitHelper( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + categoriesSortedByCentroid: List[Int], + parentCalculator: Option[ImpurityCalculator]): (Int, ImpurityStats) = { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val numSplits = binAggregates.metadata.numSplits(featureIndex) + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex) + val nextCategory = categoriesSortedByCentroid(splitIndex + 1) + binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) + splitIndex += 1 + } + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last + + // Find best split. + val parentCalc = parentCalculator.getOrElse(binAggregates.getParentImpurityCalculator()) + Range(0, numSplits).map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex) + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + val gainAndImpurityStats = ImpurityUtils.calculateImpurityStats(parentCalc, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) + } + + /** + * Choose the best split for an ordered categorical feature. 
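+   * Categories are first sorted by label centroid; with K categories, only the K - 1
+   * splits along the sorted order are considered (e.g. categories sorted as [2, 0, 1]
+   * yield candidate left-side sets {2} and {2, 0}).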
+ * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + */ + private[impl] def chooseOrderedCategoricalSplit( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + // Sort feature categories by label centroid + val categoriesSortedByCentroid = sortByCentroid(binAggregates, featureIndex, featureIndexIdx) + // Get index, stats of best split + val (bestFeatureSplitIndex, bestFeatureGainStats) = orderedSplitHelper(binAggregates, + featureIndex, featureIndexIdx, categoriesSortedByCentroid, parentCalculator) + // Create result (CategoricalSplit instance) + val categoriesForSplit = + categoriesSortedByCentroid.map(_.toDouble).slice(0, bestFeatureSplitIndex + 1) + val numCategories = binAggregates.metadata.featureArity(featureIndex) + val bestFeatureSplit = + new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) + (bestFeatureSplit, bestFeatureGainStats) + } + + /** + * Choose the best split for a feature at a node. + * + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + * @return (best split, statistics for split) If no valid split was found, the returned + * ImpurityStats will have member stats.valid = false. + */ + private[impl] def chooseSplit( + statsAggregator: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + val metadata = statsAggregator.metadata + if (metadata.isCategorical(featureIndex)) { + if (metadata.isUnordered(featureIndex)) { + SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, + featureIndexIdx, featureSplits, parentCalculator) + } else { + SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, + featureIndexIdx, parentCalculator) + } + } else { + SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, featureIndexIdx, + featureSplits, parentCalculator) + } + + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index f3dbfd96e181..029a709f553d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -75,8 +75,9 @@ class InformationGainStats( * @param impurityCalculator impurity statistics for current node * @param leftImpurityCalculator impurity statistics for left child node * @param rightImpurityCalculator impurity statistics for right child node - * @param valid whether the current split satisfies minimum info gain or - * minimum number of instances per node + * @param valid whether the current split should be performed; true if split + * satisfies minimum info gain, minimum number of instances per node, and + * has positive info gain. */ private[spark] class ImpurityStats( val gain: Double, @@ -112,7 +113,7 @@ private[spark] object ImpurityStats { * minimum number of instances per node. 
*/ def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { - new ImpurityStats(Double.MinValue, impurityCalculator.calculate(), + new ImpurityStats(Double.MinValue, impurity = -1, impurityCalculator, null, null, false) } From a2357c95672e94a148051d00e26b89245eb8e204 Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Wed, 15 Nov 2017 09:57:55 -0800 Subject: [PATCH 2/8] WIP adding TreeSplitUtilsSuite --- .../ml/tree/impl/TreeSplitUtilsSuite.scala | 270 ++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala new file mode 100644 index 000000000000..f3e5e05d040e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, Split} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.tree.impurity.{Entropy, Impurity} +import org.apache.spark.mllib.tree.model.ImpurityStats +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** Suite exercising helper methods for making split decisions during decision tree training. */ +class TreeSplitUtilsSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + /** + * Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated + * with the data from the specified training points. + */ + private def getAggregator( + metadata: DecisionTreeMetadata, + values: Array[Int], + from: Int, + to: Int, + labels: Array[Double], + featureSplits: Array[Split]): DTStatsAggregator = { + + val statsAggregator = new DTStatsAggregator(metadata, featureSubset = None) + val instanceWeights = Array.fill[Double](values.length)(1.0) + val indices = values.indices.toArray + AggUpdateUtils.updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels) + LocalDecisionTree.updateAggregator(statsAggregator, col, indices, instanceWeights, labels, + from, to, col.featureIndex, featureSplits) + statsAggregator + } + + /** Check that left/right impurities match what we'd expect for a split. 
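+   * Assumes unit instance weights and entropy impurity, so the expected stats arrays are raw
+   * label counts.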
*/ + private def validateImpurityStats( + impurity: Impurity, + labels: Array[Double], + stats: ImpurityStats, + expectedLeftStats: Array[Double], + expectedRightStats: Array[Double]): Unit = { + // Verify that impurity stats were computed correctly for split + val numClasses = (labels.max + 1).toInt + val fullImpurityStatsArray + = Array.tabulate[Double](numClasses)((label: Int) => labels.count(_ == label).toDouble) + val fullImpurity = Entropy.calculate(fullImpurityStatsArray, labels.length) + assert(stats.impurityCalculator.stats === fullImpurityStatsArray) + assert(stats.impurity === fullImpurity) + assert(stats.leftImpurityCalculator.stats === expectedLeftStats) + assert(stats.rightImpurityCalculator.stats === expectedRightStats) + assert(stats.valid) + } + + /* * * * * * * * * * * Choosing Splits * * * * * * * * * * */ + + test("chooseSplit: choose correct type of split (continuous split)") { + // Construct (binned) continuous data + val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0) + val col = FeatureColumn(featureIndex = 0, values = Array(8, 1, 1, 2, 3, 5, 6)) + // Get an array of continuous splits corresponding to values in our binned data + val splits = LocalTreeTests.getContinuousSplits(1.to(8).toArray, featureIndex = 0) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = 7, + numFeatures = 1, numClasses = 2, Map.empty) + val statsAggregator = getAggregator(metadata, col, from = 1, to = 4, labels, splits) + // Choose split, check that it's a valid ContinuousSplit + val (split1, stats1) = SplitUtils.chooseSplit(statsAggregator, col.featureIndex, + col.featureIndex, splits) + assert(stats1.valid && split1.isInstanceOf[ContinuousSplit]) + } + + test("chooseSplit: choose correct type of split (categorical split)") { + // Construct categorical data + val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0) + val featureIndex = 0 + val featureArity = 3 + val values = Array(0, 0, 1, 1, 1, 2, 2) + val col = FeatureColumn(featureIndex, values) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = 7, + numFeatures = 1, numClasses = 2, Map(featureIndex -> featureArity)) + val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) + val statsAggregator = getAggregator(metadata, col, from = 1, to = 4, labels, splits) + // Choose split, check that it's a valid categorical split + val (split2, stats2) = SplitUtils.chooseSplit(statsAggregator = statsAggregator, + featureIndex = col.featureIndex, featureIndexIdx = col.featureIndex, + featureSplits = splits) + assert(stats2.valid && split2.isInstanceOf[CategoricalSplit]) + } + + test("chooseOrderedCategoricalSplit: basic case") { + // Helper method for testing ordered categorical split + def testHelper( + values: Array[Int], + labels: Array[Double], + expectedLeftCategories: Array[Double], + expectedLeftStats: Array[Double], + expectedRightStats: Array[Double]): Unit = { + val featureIndex = 0 + // Construct FeatureVector to store categorical data + val featureArity = values.max + 1 + val arityMap = Map[Int, Int](featureIndex -> featureArity) + val col = FeatureColumn(featureIndex = 0, values = values) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, arityMap, unorderedFeatures = Some(Set.empty)) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, + labels, featureSplits = 
Array.empty) + // Choose split + val (split, stats) = + SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, col.featureIndex, + col.featureIndex) + // Verify that split has the expected left-side/right-side categories + val expectedRightCategories = Range(0, featureArity) + .filter(c => !expectedLeftCategories.contains(c)).map(_.toDouble).toArray + split match { + case s: CategoricalSplit => + assert(s.featureIndex === featureIndex) + assert(s.leftCategories === expectedLeftCategories) + assert(s.rightCategories === expectedRightCategories) + case _ => + throw new AssertionError( + s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}") + } + validateImpurityStats(Entropy, labels, stats, expectedLeftStats, expectedRightStats) + } + + val values = Array(0, 0, 1, 2, 2, 2, 2) + val labels1 = Array(0, 0, 1, 1, 1, 1, 1).map(_.toDouble) + testHelper(values, labels1, Array(0.0), Array(2.0, 0.0), Array(0.0, 5.0)) + + val labels2 = Array(0, 0, 0, 1, 1, 1, 1).map(_.toDouble) + testHelper(values, labels2, Array(0.0, 1.0), Array(3.0, 0.0), Array(0.0, 4.0)) + } + + test("chooseOrderedCategoricalSplit: return bad stats if we should not split") { + // Construct categorical data + val featureIndex = 0 + val values = Array(0, 0, 1, 2, 2, 2, 2) + val featureArity = values.max + 1 + val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) + val col = FeatureColumn(featureIndex, values) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map(featureIndex -> featureArity), unorderedFeatures = Some(Set.empty)) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, + labels, featureSplits = Array.empty) + // Choose split, verify that it's invalid + val (_, stats) = SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, col.featureIndex, + col.featureIndex) + assert(!stats.valid) + } + + test("chooseUnorderedCategoricalSplit: basic case") { + val featureIndex = 0 + // Construct data for unordered categorical feature + // label: 0 --> values: 1 + // label: 1 --> values: 0, 2 + // label: 2 --> values: 2 + val values = Array(1, 1, 0, 2, 2) + val featureArity = values.max + 1 + val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0) + val col = FeatureColumn(featureIndex, values) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 3, Map(featureIndex -> featureArity)) + val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, + labels, splits) + // Choose split + val (split, stats) = + SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, col.featureIndex, + col.featureIndex, splits) + // Verify that split has the expected left-side/right-side categories + split match { + case s: CategoricalSplit => + assert(s.featureIndex === featureIndex) + assert(s.leftCategories.toSet === Set(1.0)) + assert(s.rightCategories.toSet === Set(0.0, 2.0)) + case _ => + throw new AssertionError( + s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}") + } + validateImpurityStats(Entropy, labels, stats, expectedLeftStats = Array(2.0, 0.0, 0.0), + expectedRightStats = Array(0.0, 2.0, 1.0)) + } + + test("chooseUnorderedCategoricalSplit: return bad stats if we should not split") { + // Construct data for unordered categorical feature + val featureIndex = 0 + val featureArity = 
4 + val values = Array(3, 1, 0, 2, 2) + val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0) + val col = FeatureColumn(featureIndex, values) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map(featureIndex -> featureArity)) + val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits) + // Choose split, verify that it's invalid + val (_, stats) = SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, + featureIndex, splits) + assert(!stats.valid) + } + + test("chooseContinuousSplit: basic case") { + // Construct data for continuous feature + val featureIndex = 0 + val thresholds = Array(0, 1, 2, 3) + val values = thresholds.indices.toArray + val labels = Array(0.0, 0.0, 1.0, 1.0) + val col = FeatureColumn(featureIndex = featureIndex, values = values) + + // Construct DTStatsAggregator, compute sufficient stats + val splits = LocalTreeTests.getContinuousSplits(thresholds, featureIndex) + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map.empty) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits) + + // Choose split, verify that it has expected threshold + val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, + featureIndex, splits) + split match { + case s: ContinuousSplit => + assert(s.featureIndex === featureIndex) + assert(s.threshold === 1) + case _ => + throw new AssertionError( + s"Expected ContinuousSplit but got ${split.getClass.getSimpleName}") + } + // Verify impurity stats of split + validateImpurityStats(Entropy, labels, stats, expectedLeftStats = Array(2.0, 0.0), + expectedRightStats = Array(0.0, 2.0)) + } + + test("chooseContinuousSplit: return bad stats if we should not split") { + // Construct data for continuous feature + val featureIndex = 0 + val thresholds = Array(0, 1, 2, 3) + val values = thresholds.indices.toArray + val labels = Array(0.0, 0.0, 0.0, 0.0, 0.0) + val col = FeatureColumn(featureIndex = featureIndex, values = values) + // Construct DTStatsAggregator, compute sufficient stats + val splits = LocalTreeTests.getContinuousSplits(thresholds, featureIndex) + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map.empty[Int, Int]) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits) + // Choose split, verify that it's invalid + val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, + featureIndex, splits) + assert(!stats.valid) + } +} From 320c32ee8d0ac9bde457b0286d064470648c73af Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Wed, 15 Nov 2017 11:37:56 -0800 Subject: [PATCH 3/8] WIP --- .../ml/tree/impl/TreeSplitUtilsSuite.scala | 6 +- .../apache/spark/ml/tree/impl/TreeTests.scala | 63 +++++++++++++++++++ 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala index f3e5e05d040e..069e8c5ec2ad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala @@ -75,7 +75,7 @@ class TreeSplitUtilsSuite val labels = 
Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0) val col = FeatureColumn(featureIndex = 0, values = Array(8, 1, 1, 2, 3, 5, 6)) // Get an array of continuous splits corresponding to values in our binned data - val splits = LocalTreeTests.getContinuousSplits(1.to(8).toArray, featureIndex = 0) + val splits = TreeTests.getContinuousSplits(1.to(8).toArray, featureIndex = 0) // Construct DTStatsAggregator, compute sufficient stats val metadata = TreeTests.getMetadata(numExamples = 7, numFeatures = 1, numClasses = 2, Map.empty) @@ -229,7 +229,7 @@ class TreeSplitUtilsSuite val col = FeatureColumn(featureIndex = featureIndex, values = values) // Construct DTStatsAggregator, compute sufficient stats - val splits = LocalTreeTests.getContinuousSplits(thresholds, featureIndex) + val splits = TreeTests.getContinuousSplits(thresholds, featureIndex) val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, numClasses = 2, Map.empty) val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits) @@ -258,7 +258,7 @@ class TreeSplitUtilsSuite val labels = Array(0.0, 0.0, 0.0, 0.0, 0.0) val col = FeatureColumn(featureIndex = featureIndex, values = values) // Construct DTStatsAggregator, compute sufficient stats - val splits = LocalTreeTests.getContinuousSplits(thresholds, featureIndex) + val splits = TreeTests.getContinuousSplits(thresholds, featureIndex) val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, numClasses = 2, Map.empty[Int, Int]) val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b6894b30b0c2..1013841745ed 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -25,6 +25,7 @@ import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericA import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Impurity} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession} @@ -101,6 +102,48 @@ private[ml] object TreeTests extends SparkFunSuite { data.select(data(featuresColName), data(labelColName).as(labelColName, labelMetadata)) } + /** Returns a DecisionTreeMetadata instance with hard-coded values for use in tests */ + def getMetadata( + numExamples: Int, + numFeatures: Int, + numClasses: Int, + featureArity: Map[Int, Int], + impurity: Impurity = Entropy, + unorderedFeatures: Option[Set[Int]] = None): DecisionTreeMetadata = { + // By default, assume all categorical features within tests + // have small enough arity to be treated as unordered + val unordered = unorderedFeatures.getOrElse(featureArity.keys.toSet) + + // Set numBins appropriately for categorical features + val maxBins = 4 + val numBins: Array[Int] = 0.until(numFeatures).toArray.map { featureIndex => + if (featureArity.contains(featureIndex) && featureArity(featureIndex) > 0) { + featureArity(featureIndex) + } else { + maxBins + } + } + + new DecisionTreeMetadata(numFeatures = numFeatures, numExamples = numExamples, + numClasses = numClasses, maxBins = maxBins, minInfoGain = 0.0, featureArity = featureArity, + unorderedFeatures = unordered, numBins = numBins, impurity = impurity, + quantileStrategy = 
null, maxDepth = 5, minInstancesPerNode = 1, numTrees = 1, + numFeaturesPerNode = 2) + } + + /** + * Returns an array of continuous splits for the feature with index featureIndex and the passed-in + * set of values. Creates one continuous split per value in values. + */ + private[impl] def getContinuousSplits( + values: Array[Int], + featureIndex: Int): Array[Split] = { + val splits = values.sorted.map { + new ContinuousSplit(featureIndex, _).asInstanceOf[Split] + } + splits + } + /** * Check if the two trees are exactly the same. * Note: I hesitate to override Node.equals since it could cause problems if users @@ -194,6 +237,26 @@ private[ml] object TreeTests extends SparkFunSuite { new LabeledPoint(14.0, Vectors.dense(Array(5.0))) )) + /** + * Create toy data that can be used for testing deep tree training; the generated data requires + * [[depth]] splits to split fully. Thus a tree fit on the generated data should have a depth of + * [[depth]] (unless splitting halts early due to other constraints e.g. max depth or min + * info gain). + */ + def deepTreeData(sc: SparkContext, depth: Int): RDD[LabeledPoint] = { + // Create a dataset with [[depth]] binary features; a training point has a label of 1 + // iff all features have a value of 1. + sc.parallelize(Range(0, depth + 1).map { idx => + val features = Array.fill[Double](depth)(1) + if (idx == depth) { + LabeledPoint(1.0, Vectors.dense(features)) + } else { + features(idx) = 0.0 + LabeledPoint(0.0, Vectors.dense(features)) + } + }) + } + /** * Mapping from all Params to valid settings which differ from the defaults. * This is useful for tests which need to exercise all Params, such as save/load. From b93f9f3da9cca0887c0264162f5b032f14fa87d7 Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Wed, 15 Nov 2017 11:57:25 -0800 Subject: [PATCH 4/8] Add TreeSplitUtilsSuite, refactor it to not depend on any local tree training code --- .../ml/tree/impl/TreeSplitUtilsSuite.scala | 83 ++++++++++++------- 1 file changed, 54 insertions(+), 29 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala index 069e8c5ec2ad..b8d8b8d88b60 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala @@ -28,6 +28,35 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class TreeSplitUtilsSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + /** + * Iterate over feature values and labels for a specific (node, feature), updating stats + * aggregator for the current node. 
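+   * Dispatches to AggUpdateUtils.updateUnorderedFeature or AggUpdateUtils.updateOrderedFeature
+   * for each row, depending on whether the feature is treated as unordered in the metadata.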
+ */ + private[impl] def updateAggregator( + statsAggregator: DTStatsAggregator, + featureIndex: Int, + values: Array[Int], + indices: Array[Int], + instanceWeights: Array[Double], + labels: Array[Double], + from: Int, + to: Int, + featureIndexIdx: Int, + featureSplits: Array[Split]): Unit = { + val metadata = statsAggregator.metadata + from.until(to).foreach { idx => + val rowIndex = indices(idx) + if (metadata.isUnordered(featureIndex)) { + AggUpdateUtils.updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex), + featureIndex = featureIndex, featureIndexIdx, featureSplits, + instanceWeight = instanceWeights(rowIndex)) + } else { + AggUpdateUtils.updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex), + featureIndexIdx, instanceWeight = instanceWeights(rowIndex)) + } + } + } + /** * Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated * with the data from the specified training points. @@ -40,12 +69,13 @@ class TreeSplitUtilsSuite labels: Array[Double], featureSplits: Array[Split]): DTStatsAggregator = { + val featureIndex = 0 val statsAggregator = new DTStatsAggregator(metadata, featureSubset = None) val instanceWeights = Array.fill[Double](values.length)(1.0) val indices = values.indices.toArray AggUpdateUtils.updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels) - LocalDecisionTree.updateAggregator(statsAggregator, col, indices, instanceWeights, labels, - from, to, col.featureIndex, featureSplits) + updateAggregator(statsAggregator, featureIndex = 0, values, indices, instanceWeights, labels, + from, to, featureIndex, featureSplits) statsAggregator } @@ -73,34 +103,34 @@ class TreeSplitUtilsSuite test("chooseSplit: choose correct type of split (continuous split)") { // Construct (binned) continuous data val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0) - val col = FeatureColumn(featureIndex = 0, values = Array(8, 1, 1, 2, 3, 5, 6)) + val values = Array(8, 1, 1, 2, 3, 5, 6) + val featureIndex = 0 // Get an array of continuous splits corresponding to values in our binned data val splits = TreeTests.getContinuousSplits(1.to(8).toArray, featureIndex = 0) // Construct DTStatsAggregator, compute sufficient stats val metadata = TreeTests.getMetadata(numExamples = 7, numFeatures = 1, numClasses = 2, Map.empty) - val statsAggregator = getAggregator(metadata, col, from = 1, to = 4, labels, splits) + val statsAggregator = getAggregator(metadata, values, from = 1, to = 4, labels, splits) // Choose split, check that it's a valid ContinuousSplit - val (split1, stats1) = SplitUtils.chooseSplit(statsAggregator, col.featureIndex, - col.featureIndex, splits) + val (split1, stats1) = SplitUtils.chooseSplit(statsAggregator, featureIndex, featureIndex, + splits) assert(stats1.valid && split1.isInstanceOf[ContinuousSplit]) } test("chooseSplit: choose correct type of split (categorical split)") { // Construct categorical data val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0) - val featureIndex = 0 val featureArity = 3 val values = Array(0, 0, 1, 1, 1, 2, 2) - val col = FeatureColumn(featureIndex, values) + val featureIndex = 0 // Construct DTStatsAggregator, compute sufficient stats val metadata = TreeTests.getMetadata(numExamples = 7, numFeatures = 1, numClasses = 2, Map(featureIndex -> featureArity)) val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) - val statsAggregator = getAggregator(metadata, col, from = 1, to = 4, labels, splits) + val statsAggregator = getAggregator(metadata, values, 
from = 1, to = 4, labels, splits) // Choose split, check that it's a valid categorical split val (split2, stats2) = SplitUtils.chooseSplit(statsAggregator = statsAggregator, - featureIndex = col.featureIndex, featureIndexIdx = col.featureIndex, + featureIndex = featureIndex, featureIndexIdx = featureIndex, featureSplits = splits) assert(stats2.valid && split2.isInstanceOf[CategoricalSplit]) } @@ -117,16 +147,14 @@ class TreeSplitUtilsSuite // Construct FeatureVector to store categorical data val featureArity = values.max + 1 val arityMap = Map[Int, Int](featureIndex -> featureArity) - val col = FeatureColumn(featureIndex = 0, values = values) // Construct DTStatsAggregator, compute sufficient stats val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, numClasses = 2, arityMap, unorderedFeatures = Some(Set.empty)) - val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, + val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels, featureSplits = Array.empty) // Choose split val (split, stats) = - SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, col.featureIndex, - col.featureIndex) + SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, featureIndex) // Verify that split has the expected left-side/right-side categories val expectedRightCategories = Range(0, featureArity) .filter(c => !expectedLeftCategories.contains(c)).map(_.toDouble).toArray @@ -156,15 +184,14 @@ class TreeSplitUtilsSuite val values = Array(0, 0, 1, 2, 2, 2, 2) val featureArity = values.max + 1 val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) - val col = FeatureColumn(featureIndex, values) // Construct DTStatsAggregator, compute sufficient stats val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, numClasses = 2, Map(featureIndex -> featureArity), unorderedFeatures = Some(Set.empty)) - val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, + val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels, featureSplits = Array.empty) // Choose split, verify that it's invalid - val (_, stats) = SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, col.featureIndex, - col.featureIndex) + val (_, stats) = SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, + featureIndex) assert(!stats.valid) } @@ -177,17 +204,16 @@ class TreeSplitUtilsSuite val values = Array(1, 1, 0, 2, 2) val featureArity = values.max + 1 val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0) - val col = FeatureColumn(featureIndex, values) // Construct DTStatsAggregator, compute sufficient stats val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, numClasses = 3, Map(featureIndex -> featureArity)) val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) - val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, + val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels, splits) // Choose split val (split, stats) = - SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, col.featureIndex, - col.featureIndex, splits) + SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, featureIndex, + splits) // Verify that split has the expected left-side/right-side categories split match { case s: CategoricalSplit => @@ -208,12 +234,12 @@ class TreeSplitUtilsSuite val featureArity = 4 val values = Array(3, 1, 0, 2, 2) val labels = 
Array(1.0, 1.0, 1.0, 1.0, 1.0) - val col = FeatureColumn(featureIndex, values) // Construct DTStatsAggregator, compute sufficient stats val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, numClasses = 2, Map(featureIndex -> featureArity)) val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) - val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits) + val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels, + splits) // Choose split, verify that it's invalid val (_, stats) = SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, featureIndex, splits) @@ -226,13 +252,12 @@ class TreeSplitUtilsSuite val thresholds = Array(0, 1, 2, 3) val values = thresholds.indices.toArray val labels = Array(0.0, 0.0, 1.0, 1.0) - val col = FeatureColumn(featureIndex = featureIndex, values = values) - // Construct DTStatsAggregator, compute sufficient stats val splits = TreeTests.getContinuousSplits(thresholds, featureIndex) val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, numClasses = 2, Map.empty) - val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits) + val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels, + splits) // Choose split, verify that it has expected threshold val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, @@ -256,12 +281,12 @@ class TreeSplitUtilsSuite val thresholds = Array(0, 1, 2, 3) val values = thresholds.indices.toArray val labels = Array(0.0, 0.0, 0.0, 0.0, 0.0) - val col = FeatureColumn(featureIndex = featureIndex, values = values) // Construct DTStatsAggregator, compute sufficient stats val splits = TreeTests.getContinuousSplits(thresholds, featureIndex) val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, numClasses = 2, Map.empty[Int, Int]) - val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits) + val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels, + splits) // Choose split, verify that it's invalid val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, featureIndex, splits) From b4a5f3b2204e8ab957864103424c269cbb1da81b Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Wed, 15 Nov 2017 12:08:59 -0800 Subject: [PATCH 5/8] Remove deep tree test method from TreeTests.scala --- .../apache/spark/ml/tree/impl/TreeTests.scala | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index 1013841745ed..0f26f7bfeed6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -237,26 +237,6 @@ private[ml] object TreeTests extends SparkFunSuite { new LabeledPoint(14.0, Vectors.dense(Array(5.0))) )) - /** - * Create toy data that can be used for testing deep tree training; the generated data requires - * [[depth]] splits to split fully. Thus a tree fit on the generated data should have a depth of - * [[depth]] (unless splitting halts early due to other constraints e.g. max depth or min - * info gain). 
- */ - def deepTreeData(sc: SparkContext, depth: Int): RDD[LabeledPoint] = { - // Create a dataset with [[depth]] binary features; a training point has a label of 1 - // iff all features have a value of 1. - sc.parallelize(Range(0, depth + 1).map { idx => - val features = Array.fill[Double](depth)(1) - if (idx == depth) { - LabeledPoint(1.0, Vectors.dense(features)) - } else { - features(idx) = 0.0 - LabeledPoint(0.0, Vectors.dense(features)) - } - }) - } - /** * Mapping from all Params to valid settings which differ from the defaults. * This is useful for tests which need to exercise all Params, such as save/load. From 31ef80bb2ee91dd46561b87a52fb65addfa81059 Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Wed, 15 Nov 2017 12:15:59 -0800 Subject: [PATCH 6/8] WIP simplifying TreeSplitUtilsSuite --- .../ml/tree/impl/TreeSplitUtilsSuite.scala | 49 +++++++------------ 1 file changed, 17 insertions(+), 32 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala index b8d8b8d88b60..d5b27dfa7a8a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala @@ -28,35 +28,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class TreeSplitUtilsSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - /** - * Iterate over feature values and labels for a specific (node, feature), updating stats - * aggregator for the current node. - */ - private[impl] def updateAggregator( - statsAggregator: DTStatsAggregator, - featureIndex: Int, - values: Array[Int], - indices: Array[Int], - instanceWeights: Array[Double], - labels: Array[Double], - from: Int, - to: Int, - featureIndexIdx: Int, - featureSplits: Array[Split]): Unit = { - val metadata = statsAggregator.metadata - from.until(to).foreach { idx => - val rowIndex = indices(idx) - if (metadata.isUnordered(featureIndex)) { - AggUpdateUtils.updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex), - featureIndex = featureIndex, featureIndexIdx, featureSplits, - instanceWeight = instanceWeights(rowIndex)) - } else { - AggUpdateUtils.updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex), - featureIndexIdx, instanceWeight = instanceWeights(rowIndex)) - } - } - } - /** * Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated * with the data from the specified training points. 
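The hunk below inlines the per-row sufficient-stats update that the deleted updateAggregator helper used to perform: each row of an ordered feature updates the stats bin addressed by its binned value, and prefix sums over those bins then give the left-child class counts for every candidate threshold. A minimal standalone sketch of that bookkeeping, using a toy count matrix rather than the real DTStatsAggregator (illustrative only; all names here are hypothetical):

object ToyBinStatsSketch {
  def main(args: Array[String]): Unit = {
    val numBins = 4
    val numClasses = 2
    // stats(bin)(label) = count of training points with this label in this bin
    val stats = Array.fill(numBins, numClasses)(0.0)

    val values = Array(0, 1, 3, 3)         // binned values of one ordered feature
    val labels = Array(0.0, 0.0, 1.0, 1.0)

    // Ordered-feature update: the bin index is the binned feature value.
    values.zip(labels).foreach { case (bin, label) =>
      stats(bin)(label.toInt) += 1.0
    }

    // Scoring a candidate threshold is a prefix sum over bins: everything in
    // bins 0..1 would go to the left child of a split between bins 1 and 2.
    val leftCount = stats.take(2).map(_.sum).sum
    println(s"points on the left of the split: $leftCount")  // prints 2.0
  }
}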
@@ -71,11 +42,22 @@ class TreeSplitUtilsSuite
     val featureIndex = 0
     val statsAggregator = new DTStatsAggregator(metadata, featureSubset = None)
-    val instanceWeights = Array.fill[Double](values.length)(1.0)
     val indices = values.indices.toArray
+    val instanceWeights = Array.fill[Double](values.length)(1.0)
+    // Update parent impurity stats
     AggUpdateUtils.updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels)
-    updateAggregator(statsAggregator, featureIndex = 0, values, indices, instanceWeights, labels,
-      from, to, featureIndex, featureSplits)
+    // Update current aggregator's impurity stats
+    from.until(to).foreach { idx =>
+      val rowIndex = indices(idx)
+      if (metadata.isUnordered(featureIndex)) {
+        AggUpdateUtils.updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex),
+          featureIndex = featureIndex, featureIndexIdx = 0, featureSplits,
+          instanceWeight = 1.0)
+      } else {
+        AggUpdateUtils.updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex),
+          featureIndexIdx = 0, instanceWeight = 1.0)
+      }
+    }
     statsAggregator
   }

From b6291e1dd9670b74f3643c5bdc10b5bc7d79e66c Mon Sep 17 00:00:00 2001
From: Sid Murching
Date: Wed, 15 Nov 2017 13:51:35 -0800
Subject: [PATCH 7/8] Simplify TreeSplitUtilsSuite

---
 .../ml/tree/impl/TreeSplitUtilsSuite.scala    | 112 +++++++++---------
 .../apache/spark/ml/tree/impl/TreeTests.scala |   6 +-
 2 files changed, 59 insertions(+), 59 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala
index d5b27dfa7a8a..756c09f3b99a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala
@@ -30,49 +30,51 @@ class TreeSplitUtilsSuite
   /**
    * Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated
-   * with the data from the specified training points.
+   * with the data from the specified training points. Assumes a feature index of 0 and that
+   * all training points have the same weight (1.0).
   */
   private def getAggregator(
       metadata: DecisionTreeMetadata,
       values: Array[Int],
-      from: Int,
-      to: Int,
       labels: Array[Double],
      featureSplits: Array[Split]): DTStatsAggregator = {
-
-    val featureIndex = 0
+    // Create stats aggregator
     val statsAggregator = new DTStatsAggregator(metadata, featureSubset = None)
-    val indices = values.indices.toArray
-    val instanceWeights = Array.fill[Double](values.length)(1.0)
     // Update parent impurity stats
-    AggUpdateUtils.updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels)
+    val featureIndex = 0
+    val instanceWeights = Array.fill[Double](values.length)(1.0)
+    AggUpdateUtils.updateParentImpurity(statsAggregator, indices = values.indices.toArray,
+      from = 0, to = values.length, instanceWeights, labels)
     // Update current aggregator's impurity stats
-    from.until(to).foreach { idx =>
-      val rowIndex = indices(idx)
+    values.zip(labels).foreach { case (value: Int, label: Double) =>
       if (metadata.isUnordered(featureIndex)) {
-        AggUpdateUtils.updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex),
-          featureIndex = featureIndex, featureIndexIdx = 0, featureSplits,
-          instanceWeight = 1.0)
+        AggUpdateUtils.updateUnorderedFeature(statsAggregator, value, label,
+          featureIndex = featureIndex, featureIndexIdx = 0, featureSplits, instanceWeight = 1.0)
       } else {
-        AggUpdateUtils.updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex),
-          featureIndexIdx = 0, instanceWeight = 1.0)
+        AggUpdateUtils.updateOrderedFeature(statsAggregator, value, label, featureIndexIdx = 0,
+          instanceWeight = 1.0)
      }
     }
     statsAggregator
   }
 
-  /** Check that left/right impurities match what we'd expect for a split. */
+  /**
+   * Check that left/right impurities match what we'd expect for a split.
+ * @param labels Labels whose impurity information should be reflected in stats + * @param stats ImpurityStats object containing impurity info for the left/right sides of a split + */ private def validateImpurityStats( impurity: Impurity, labels: Array[Double], stats: ImpurityStats, expectedLeftStats: Array[Double], expectedRightStats: Array[Double]): Unit = { - // Verify that impurity stats were computed correctly for split + // Compute impurity for our data points manually val numClasses = (labels.max + 1).toInt val fullImpurityStatsArray = Array.tabulate[Double](numClasses)((label: Int) => labels.count(_ == label).toDouble) val fullImpurity = Entropy.calculate(fullImpurityStatsArray, labels.length) + // Verify that impurity stats were computed correctly for split assert(stats.impurityCalculator.stats === fullImpurityStatsArray) assert(stats.impurity === fullImpurity) assert(stats.leftImpurityCalculator.stats === expectedLeftStats) @@ -87,37 +86,37 @@ class TreeSplitUtilsSuite test("chooseSplit: choose correct type of split (continuous split)") { // Construct (binned) continuous data - val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0) - val values = Array(8, 1, 1, 2, 3, 5, 6) + val labels = Array(0.0, 0.0, 1.0) + val values = Array(1, 2, 3) val featureIndex = 0 // Get an array of continuous splits corresponding to values in our binned data - val splits = TreeTests.getContinuousSplits(1.to(8).toArray, featureIndex = 0) + val splits = TreeTests.getContinuousSplits(thresholds = values.distinct.sorted, + featureIndex = 0) // Construct DTStatsAggregator, compute sufficient stats - val metadata = TreeTests.getMetadata(numExamples = 7, - numFeatures = 1, numClasses = 2, Map.empty) - val statsAggregator = getAggregator(metadata, values, from = 1, to = 4, labels, splits) + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map.empty) + val statsAggregator = getAggregator(metadata, values, labels, splits) // Choose split, check that it's a valid ContinuousSplit - val (split1, stats1) = SplitUtils.chooseSplit(statsAggregator, featureIndex, featureIndex, + val (split, stats) = SplitUtils.chooseSplit(statsAggregator, featureIndex, featureIndex, splits) - assert(stats1.valid && split1.isInstanceOf[ContinuousSplit]) + assert(stats.valid && split.isInstanceOf[ContinuousSplit]) } test("chooseSplit: choose correct type of split (categorical split)") { // Construct categorical data - val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0) + val labels = Array(0.0, 0.0, 1.0, 1.0, 1.0) val featureArity = 3 - val values = Array(0, 0, 1, 1, 1, 2, 2) + val values = Array(0, 0, 1, 2, 2) val featureIndex = 0 // Construct DTStatsAggregator, compute sufficient stats - val metadata = TreeTests.getMetadata(numExamples = 7, - numFeatures = 1, numClasses = 2, Map(featureIndex -> featureArity)) + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map(featureIndex -> featureArity)) val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) - val statsAggregator = getAggregator(metadata, values, from = 1, to = 4, labels, splits) + val statsAggregator = getAggregator(metadata, values, labels, splits) // Choose split, check that it's a valid categorical split - val (split2, stats2) = SplitUtils.chooseSplit(statsAggregator = statsAggregator, - featureIndex = featureIndex, featureIndexIdx = featureIndex, - featureSplits = splits) - assert(stats2.valid && split2.isInstanceOf[CategoricalSplit]) + val (split, stats) 
= SplitUtils.chooseSplit(statsAggregator = statsAggregator,
+      featureIndex = featureIndex, featureIndexIdx = featureIndex, featureSplits = splits)
+    assert(stats.valid && split.isInstanceOf[CategoricalSplit])
   }
 
   test("chooseOrderedCategoricalSplit: basic case") {
@@ -128,15 +127,14 @@ class TreeSplitUtilsSuite
         expectedLeftCategories: Array[Double],
         expectedLeftStats: Array[Double],
         expectedRightStats: Array[Double]): Unit = {
+      // Set up metadata for ordered categorical feature
       val featureIndex = 0
-      // Construct FeatureVector to store categorical data
       val featureArity = values.max + 1
       val arityMap = Map[Int, Int](featureIndex -> featureArity)
-      // Construct DTStatsAggregator, compute sufficient stats
       val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
         numClasses = 2, arityMap, unorderedFeatures = Some(Set.empty))
-      val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length,
-        labels, featureSplits = Array.empty)
+      // Construct DTStatsAggregator, compute sufficient stats
+      val statsAggregator = getAggregator(metadata, values, labels, featureSplits = Array.empty)
       // Choose split
       val (split, stats) =
         SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, featureIndex)
@@ -155,12 +153,18 @@ class TreeSplitUtilsSuite
       validateImpurityStats(Entropy, labels, stats, expectedLeftStats, expectedRightStats)
     }
 
+    // Test a single split: The left side of our split should contain the two points with label 0,
+    // the right side of our split should contain the five points with label 1
     val values = Array(0, 0, 1, 2, 2, 2, 2)
     val labels1 = Array(0, 0, 1, 1, 1, 1, 1).map(_.toDouble)
-    testHelper(values, labels1, Array(0.0), Array(2.0, 0.0), Array(0.0, 5.0))
+    testHelper(values, labels1, expectedLeftCategories = Array(0.0),
+      expectedLeftStats = Array(2.0, 0.0), expectedRightStats = Array(0.0, 5.0))
 
+    // Test a single split: The left side of our split should contain the three points with label 0,
+    // the right side of our split should contain the four points with label 1
     val labels2 = Array(0, 0, 0, 1, 1, 1, 1).map(_.toDouble)
-    testHelper(values, labels2, Array(0.0, 1.0), Array(3.0, 0.0), Array(0.0, 4.0))
+    testHelper(values, labels2, expectedLeftCategories = Array(0.0, 1.0),
+      expectedLeftStats = Array(3.0, 0.0), expectedRightStats = Array(0.0, 4.0))
   }
 
   test("chooseOrderedCategoricalSplit: return bad stats if we should not split") {
@@ -172,8 +176,7 @@ class TreeSplitUtilsSuite
     // Construct DTStatsAggregator, compute sufficient stats
     val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
       numClasses = 2, Map(featureIndex -> featureArity), unorderedFeatures = Some(Set.empty))
-    val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length,
-      labels, featureSplits = Array.empty)
+    val statsAggregator = getAggregator(metadata, values, labels, featureSplits = Array.empty)
     // Choose split, verify that it's invalid
     val (_, stats) = SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex,
       featureIndex)
@@ -186,6 +189,7 @@ class TreeSplitUtilsSuite
     // label: 0 --> values: 1
     // label: 1 --> values: 0, 2
     // label: 2 --> values: 2
+    // Expected split: feature value 1 on the left, values (0, 2) on the right
     val values = Array(1, 1, 0, 2, 2)
     val featureArity = values.max + 1
     val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0)
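The hunks above and below both rely on RandomForest.findUnorderedSplits, which for an unordered categorical feature of arity M enumerates one candidate split per subset of categories, keeping (2^(M - 1)) - 1 of them because a subset and its complement describe the same partition. The sketch below enumerates the candidate left-side subsets for the arity-3 feature used in these tests; it is illustrative only, not the MLlib internals:

object UnorderedSplitEnumerationSketch {
  // One bitmask per candidate split; bit c set means category c goes left.
  def leftSubsets(arity: Int): Seq[Set[Int]] = {
    val numSplits = (1 << (arity - 1)) - 1
    (1 to numSplits).map { bits =>
      (0 until arity).filter(c => ((bits >> c) & 1) == 1).toSet
    }
  }

  def main(args: Array[String]): Unit = {
    // For an arity-3 feature: Set(0), Set(1), Set(0, 1)
    leftSubsets(3).foreach(println)
  }
}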
@@ -193,8 +197,7 @@ class TreeSplitUtilsSuite
     val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
       numClasses = 3, Map(featureIndex -> featureArity))
     val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
-    val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length,
-      labels, splits)
+    val statsAggregator = getAggregator(metadata, values, labels, splits)
     // Choose split
     val (split, stats) =
       SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, featureIndex,
         splits)
@@ -214,7 +217,7 @@ class TreeSplitUtilsSuite
   }
 
   test("chooseUnorderedCategoricalSplit: return bad stats if we should not split") {
-    // Construct data for unordered categorical feature
+    // Construct data for unordered categorical feature; all points have label 1
     val featureIndex = 0
     val featureArity = 4
     val values = Array(3, 1, 0, 2, 2)
@@ -223,8 +226,7 @@ class TreeSplitUtilsSuite
     val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
       numClasses = 2, Map(featureIndex -> featureArity))
     val splits = RandomForest.findUnorderedSplits(metadata, featureIndex)
-    val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels,
-      splits)
+    val statsAggregator = getAggregator(metadata, values, labels, splits)
     // Choose split, verify that it's invalid
     val (_, stats) = SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex,
       featureIndex, splits)
@@ -241,8 +243,7 @@ class TreeSplitUtilsSuite
     val splits = TreeTests.getContinuousSplits(thresholds, featureIndex)
     val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
       numClasses = 2, Map.empty)
-    val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels,
-      splits)
+    val statsAggregator = getAggregator(metadata, values, labels, splits)
 
     // Choose split, verify that it has expected threshold
     val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex,
@@ -262,7 +263,7 @@ class TreeSplitUtilsSuite
   }
 
   test("chooseContinuousSplit: return bad stats if we should not split") {
-    // Construct data for continuous feature
+    // Construct data for continuous feature; all points have label 0
     val featureIndex = 0
     val thresholds = Array(0, 1, 2, 3)
     val values = thresholds.indices.toArray
@@ -270,10 +271,9 @@ class TreeSplitUtilsSuite
     val splits = TreeTests.getContinuousSplits(thresholds, featureIndex)
     val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1,
       numClasses = 2, Map.empty[Int, Int])
-    val statsAggregator = getAggregator(metadata, values, from = 0, to = values.length, labels,
-      splits)
+    val statsAggregator = getAggregator(metadata, values, labels, splits)
     // Choose split, verify that it's invalid
-    val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex,
+    val (_, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex,
       featureIndex, splits)
     assert(!stats.valid)
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index 0f26f7bfeed6..d16ca5432cf0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -133,12 +133,12 @@ private[ml] object TreeTests extends SparkFunSuite {
   /**
    * Returns an array of continuous splits for the feature with index featureIndex and the passed-in
-   * set of values. Creates one continuous split per value in values.
+   * set of thresholds. Creates one continuous split per threshold in thresholds.
*/ private[impl] def getContinuousSplits( - values: Array[Int], + thresholds: Array[Int], featureIndex: Int): Array[Split] = { - val splits = values.sorted.map { + val splits = thresholds.sorted.map { new ContinuousSplit(featureIndex, _).asInstanceOf[Split] } splits From 5bcccda6a599f60d2d084ecf871649675a792d5d Mon Sep 17 00:00:00 2001 From: Sid Murching Date: Fri, 1 Dec 2017 09:50:06 -0800 Subject: [PATCH 8/8] Update TreeSplitUtilsSuite --- .../ml/tree/impl/TreeSplitUtilsSuite.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala index 756c09f3b99a..290f6d78ef67 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala @@ -186,13 +186,13 @@ class TreeSplitUtilsSuite test("chooseUnorderedCategoricalSplit: basic case") { val featureIndex = 0 // Construct data for unordered categorical feature - // label: 0 --> values: 1 - // label: 1 --> values: 0, 2 - // label: 2 --> values: 2 - // Expected split: feature value 1 on the left, values (0, 2) on the right - val values = Array(1, 1, 0, 2, 2) + // label: 0 --> values: 0, 1 + // label: 1 --> values: 2, 3 + // label: 2 --> values: 2, 2, 4 + // Expected split: feature values (0, 1) on the left, values (2, 3, 4) on the right + val values = Array(0, 1, 2, 3, 2, 2, 4) + val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0) val featureArity = values.max + 1 - val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0) // Construct DTStatsAggregator, compute sufficient stats val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, numClasses = 3, Map(featureIndex -> featureArity)) @@ -206,14 +206,14 @@ class TreeSplitUtilsSuite split match { case s: CategoricalSplit => assert(s.featureIndex === featureIndex) - assert(s.leftCategories.toSet === Set(1.0)) - assert(s.rightCategories.toSet === Set(0.0, 2.0)) + assert(s.leftCategories.toSet === Set(0.0, 1.0)) + assert(s.rightCategories.toSet === Set(2.0, 3.0, 4.0)) case _ => throw new AssertionError( s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}") } validateImpurityStats(Entropy, labels, stats, expectedLeftStats = Array(2.0, 0.0, 0.0), - expectedRightStats = Array(0.0, 2.0, 1.0)) + expectedRightStats = Array(0.0, 2.0, 3.0)) } test("chooseUnorderedCategoricalSplit: return bad stats if we should not split") {
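As a sanity check on the expectations introduced by this last patch: with the new data, the chosen split should send categories (0, 1) left and (2, 3, 4) right, so the left side holds the two label-0 points and the right side holds the two label-1 and three label-2 points, exactly the Array(2.0, 0.0, 0.0) and Array(0.0, 2.0, 3.0) asserted above. The per-class counts can be reproduced with plain Scala (no Spark dependencies; hypothetical object name):

object UnorderedSplitStatsCheck {
  def main(args: Array[String]): Unit = {
    val values = Array(0, 1, 2, 3, 2, 2, 4)
    val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0)
    val numClasses = 3
    val leftCategories = Set(0, 1)

    // Per-class label counts for the points falling on each side of the split
    def classCounts(onLeft: Boolean): Array[Double] = {
      val counts = Array.fill(numClasses)(0.0)
      values.zip(labels).foreach { case (category, label) =>
        if (leftCategories.contains(category) == onLeft) counts(label.toInt) += 1.0
      }
      counts
    }

    println(classCounts(onLeft = true).mkString(", "))   // 2.0, 0.0, 0.0
    println(classCounts(onLeft = false).mkString(", "))  // 0.0, 2.0, 3.0
  }
}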