RandomForest.scala
@@ -294,13 +294,9 @@ private[ml] object RandomForest extends Logging {
featureIndexIdx += 1
}
} else {
// Use all features
val numFeatures = agg.metadata.numFeatures
var featureIndex = 0
while (featureIndex < numFeatures) {
agg.metadata.featureIndexes.foreach{featureIndex =>
val binIndex = treePoint.binnedFeatures(featureIndex)
agg.update(featureIndex, binIndex, label, instanceWeight)
featureIndex += 1
}
}
}
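The hunk above drops the manual `while` counter in favor of iterating `metadata.featureIndexes`, so features filtered out of the metadata (the single-category ones) never reach the aggregator. A minimal, self-contained sketch of the pattern — `FakeMetadata` is a hypothetical stand-in for `DecisionTreeMetadata`, used only to make the contrast runnable:

```scala
// Sketch only: FakeMetadata stands in for DecisionTreeMetadata.
case class FakeMetadata(numFeatures: Int, featureIndexes: IndexedSeq[Int])

object IterationSketch {
  def main(args: Array[String]): Unit = {
    // Feature 1 is single-category and has been filtered out of featureIndexes.
    val meta = FakeMetadata(numFeatures = 3, featureIndexes = IndexedSeq(0, 2))

    // Old style: visits every feature index, including unusable ones.
    var featureIndex = 0
    while (featureIndex < meta.numFeatures) {
      println(s"old loop visits feature $featureIndex")
      featureIndex += 1
    }

    // New style: visits only the usable features.
    meta.featureIndexes.foreach { featureIndex =>
      println(s"new loop visits feature $featureIndex")
    }
  }
}
```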
@@ -829,10 +825,8 @@ private[ml] object RandomForest extends Logging {

logDebug("isMulticlass = " + metadata.isMulticlass)

val numFeatures = metadata.numFeatures

// Sample the input only if there are continuous features.
val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
val hasContinuousFeatures = metadata.featureIndexes.exists(metadata.isContinuous)
val sampledInput = if (hasContinuousFeatures) {
// Calculate the number of samples for approximate quantile calculation.
val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
@@ -847,12 +841,11 @@
new Array[LabeledPoint](0)
}

val splits = new Array[Array[Split]](numFeatures)
val splits = new Array[Array[Split]](metadata.numFeatures)

// Find all splits.
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
metadata.featureIndexes.foreach{featureIndex =>
if (metadata.isContinuous(featureIndex)) {
val featureSamples = sampledInput.map(_.features(featureIndex))
val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
@@ -893,7 +886,6 @@ private[ml] object RandomForest extends Logging {
splits(featureIndex) = new Array[Split](0)
}
}
featureIndex += 1
}
splits
}
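One consequence of the loop change worth noting: `splits` is still allocated with `metadata.numFeatures` slots, but only the indexes in `metadata.featureIndexes` are populated, so entries for skipped features stay `null` under this scheme. A toy sketch of that shape, with a `String` standing in for `Array[Split]`:

```scala
// Sketch: array sized over all features, populated only at active indexes.
object SparseSplitsSketch {
  def main(args: Array[String]): Unit = {
    val numFeatures = 5
    val featureIndexes = IndexedSeq(1, 2, 3, 4) // feature 0 skipped
    val splits = new Array[String](numFeatures) // stands in for Array[Array[Split]]
    featureIndexes.foreach { f => splits(f) = s"splits for feature $f" }
    // Skipped features read back as None here (null in the raw array).
    splits.zipWithIndex.foreach { case (s, f) => println(s"feature $f -> ${Option(s)}") }
  }
}
```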
DecisionTree.scala
@@ -347,21 +347,8 @@ object DecisionTree extends Serializable with Logging {
unorderedFeatures: Set[Int],
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
// Use subsampled features
featuresForNode.get.size
} else {
// Use all features
agg.metadata.numFeatures
}
// Iterate over features.
var featureIndexIdx = 0
while (featureIndexIdx < numFeaturesPerNode) {
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
val features: Array[Int] = featuresForNode.getOrElse(agg.metadata.featureIndexes.toArray)
features.zipWithIndex.foreach{case (featureIndex, featureIndexIdx) =>
if (unorderedFeatures.contains(featureIndex)) {
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
@@ -385,7 +372,6 @@
val binIndex = treePoint.binnedFeatures(featureIndex)
agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
}
featureIndexIdx += 1
}
}
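The rewrite above collapses the two old code paths (subsampled features vs. all features) into one: take the node's subset if present, fall back to all usable indexes otherwise, and recover both the global `featureIndex` and the node-local `featureIndexIdx` through `zipWithIndex`. A sketch of that idiom, with made-up feature indexes:

```scala
// Sketch: Option.getOrElse + zipWithIndex, as in the hunk above.
object FeatureSubsetSketch {
  def main(args: Array[String]): Unit = {
    val allUsableFeatures: Array[Int] = Array(0, 2, 3)
    val featuresForNode: Option[Array[Int]] = Some(Array(2, 3)) // per-node subsample

    val features = featuresForNode.getOrElse(allUsableFeatures)
    features.zipWithIndex.foreach { case (featureIndex, featureIndexIdx) =>
      // featureIndex addresses the input vector; featureIndexIdx addresses the
      // node-local stats, which have one slot per chosen feature.
      println(s"global feature $featureIndex -> local slot $featureIndexIdx")
    }
  }
}
```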

DTStatsAggregator.scala
@@ -96,7 +96,11 @@ private[spark] class DTStatsAggregator(
* Update the stats for a given (feature, bin) for ordered features, using the given label.
*/
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
println("Updating for "+featureIndex+"binIndex"+binIndex)
val i = featureOffsets(featureIndex) + binIndex * statsSize
println("featureOffsets are "+featureOffsets.toList)
println("featureOffsets is "+featureOffsets(featureIndex))
println("i "+i)
impurityAggregator.update(allStats, i, label, instanceWeight)
}
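For context on the index those printlns are inspecting: `allStats` is one flat array, and each (feature, bin) pair owns a contiguous block of `statsSize` entries starting at `featureOffsets(featureIndex)`. A toy sketch of that layout, with assumed sizes:

```scala
// Sketch with toy sizes: two features with 3 and 2 bins, statsSize = 4.
object OffsetSketch {
  def main(args: Array[String]): Unit = {
    val statsSize = 4
    val numBins = Array(3, 2)
    // featureOffsets(f) = statsSize * (total bins of all earlier features)
    val featureOffsets = numBins.scanLeft(0)((acc, bins) => acc + bins * statsSize)
    for (f <- numBins.indices; bin <- 0 until numBins(f)) {
      val i = featureOffsets(f) + bin * statsSize
      println(s"feature $f, bin $bin -> allStats($i .. ${i + statsSize - 1})")
    }
  }
}
```

Note that a feature whose bin count is forced to 0 contributes an empty block, so its offset equals the next feature's offset — a plausible source of the indexing confusion these debug statements are chasing.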

DecisionTreeMetadata.scala
@@ -30,12 +30,14 @@ import org.apache.spark.rdd.RDD
/**
* Learning and dataset metadata for DecisionTree.
*
* @param numFeatures Total number of features (including any single-category features).
* @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
* For regression: fixed at 0 (no meaning).
* @param maxBins Maximum number of bins, for all features.
* @param featureArity Map: categorical feature index --> arity.
* I.e., the feature takes values in {0, ..., arity - 1}.
* @param numBins Number of bins for each feature.
* @param featureIndexes Indexes of usable (i.e., non-single-category) features.
*/
private[spark] class DecisionTreeMetadata(
val numFeatures: Int,
@@ -51,7 +53,8 @@ private[spark] class DecisionTreeMetadata(
val minInstancesPerNode: Int,
val minInfoGain: Double,
val numTrees: Int,
val numFeaturesPerNode: Int) extends Serializable {
val numFeaturesPerNode: Int,
val featureIndexes: IndexedSeq[Int]) extends Serializable {

def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)

@@ -111,6 +114,14 @@ private[spark] object DecisionTreeMetadata extends Logging {
throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
s"but was given an empty one.")
}
// Construct the feature indexes that we can use (single-category features are not useful)
val featureIndexes = if (strategy.categoricalFeaturesInfo.nonEmpty) {
val singleCategoryIndexes = strategy.categoricalFeaturesInfo.filter(_._2 < 2).map(_._1).toSet
0.to(numFeatures-1).filterNot(singleCategoryIndexes.contains)
} else {
0.to(numFeatures-1)
}
val numActiveFeatures = featureIndexes.size
val numExamples = input.count()
val numClasses = strategy.algo match {
case Classification => strategy.numClasses
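The block added above is the heart of the change: compute the usable feature indexes once, up front, and size everything else off that list. A standalone sketch of the same filter on a toy categoricalFeaturesInfo map:

```scala
// Sketch: same filtering logic as the hunk above, on made-up inputs.
object FeatureIndexSketch {
  def main(args: Array[String]): Unit = {
    val numFeatures = 5
    // Feature 0 has arity 1 (single category); 2 and 4 are ordinary categoricals.
    val categoricalFeaturesInfo = Map(0 -> 1, 2 -> 2, 4 -> 4)

    val featureIndexes =
      if (categoricalFeaturesInfo.nonEmpty) {
        val singleCategoryIndexes =
          categoricalFeaturesInfo.filter(_._2 < 2).map(_._1).toSet
        0.to(numFeatures - 1).filterNot(singleCategoryIndexes.contains)
      } else {
        0.to(numFeatures - 1)
      }

    println(featureIndexes) // Vector(1, 2, 3, 4): feature 0 is dropped
  }
}
```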
@@ -144,8 +155,7 @@
val maxCategoriesForUnorderedFeature =
((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
// Hack: If a categorical feature has only 1 category, we treat it as continuous.
// TODO(SPARK-9957): Handle this properly by filtering out those features.
// Set number of bins to 0 if we are skipping a feature
if (numCategories > 1) {
// Decide if some categorical features should be treated as unordered features,
// which require 2 * ((1 << numCategories - 1) - 1) bins.
@@ -157,14 +167,18 @@
} else {
numBins(featureIndex) = numCategories
}
} else {
numBins(featureIndex) = 0
}
}
} else {
// Binary classification or regression
strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
// If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957
// Set number of bins to 0 if we are skipping a feature
if (numCategories > 1) {
numBins(featureIndex) = numCategories
} else {
numBins(featureIndex) = 0
}
}
}
@@ -184,16 +198,17 @@
case _ => featureSubsetStrategy
}
val numFeaturesPerNode: Int = _featureSubsetStrategy match {
case "all" => numFeatures
case "sqrt" => math.sqrt(numFeatures).ceil.toInt
case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
case "onethird" => (numFeatures / 3.0).ceil.toInt
case "all" => numActiveFeatures
case "sqrt" => math.sqrt(numActiveFeatures).ceil.toInt
case "log2" => math.max(1, (math.log(numActiveFeatures) / math.log(2)).ceil.toInt)
case "onethird" => (numActiveFeatures / 3.0).ceil.toInt
}

new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode,
featureIndexes)
}
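Because the subset-size math now runs on numActiveFeatures rather than the raw numFeatures, single-category features no longer inflate the per-node feature subset. A quick sketch of what each featureSubsetStrategy yields under an assumed active-feature count:

```scala
// Sketch: per-node subset sizes computed from the active-feature count.
object SubsetSizeSketch {
  def main(args: Array[String]): Unit = {
    val numActiveFeatures = 4 // e.g. 5 raw features with 1 single-category one dropped
    val sizes = Seq(
      "all" -> numActiveFeatures,
      "sqrt" -> math.sqrt(numActiveFeatures).ceil.toInt,
      "log2" -> math.max(1, (math.log(numActiveFeatures) / math.log(2)).ceil.toInt),
      "onethird" -> (numActiveFeatures / 3.0).ceil.toInt
    )
    sizes.foreach { case (strategy, n) => println(s"$strategy -> $n") }
  }
}
```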

/**
DecisionTreeSuite.scala
@@ -113,7 +113,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(6), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
0, 0, 0.0, 0, 0, 0.to(0).toArray
)
val featureSamples = Array.fill(200000)(math.random)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -130,7 +130,7 @@
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(5), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
0, 0, 0.0, 0, 0, 0.to(0).toArray
)
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -144,7 +144,7 @@
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
0, 0, 0.0, 0, 0, 0.to(0).toArray
)
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -158,7 +158,7 @@
val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
0, 0, 0.0, 0, 0
0, 0, 0.0, 0, 0, 0.to(0).toArray
)
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
RandomForestSuite.scala
@@ -53,7 +53,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// Make sure trees are the same.
assert(rfTree.toString == dt.toString)
}

/*
test("Binary classification with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -196,7 +196,32 @@
val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
featureSubsetStrategy = "sqrt", seed = 12345)
}
*/
test("filtering of 1 category categorical point") {
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0, 0.0, 3.0, 1.0))
arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0))
arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0, 0.0, 6.0, 3.0))
arr(3) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0))
val categoricalFeaturesInfo = Map(0 -> 1, 2 -> 2, 4 -> 4)
val input = sc.parallelize(arr)

val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
featureSubsetStrategy = "sqrt", seed = 12345)
// Verify that feature 0 (the single-category feature) is never used for a split.
def assertTreeDoesNotContain(node: Node, feature: Long): Unit = {
node.split.foreach(split => assert(split.feature != feature))
node.leftNode.foreach(assertTreeDoesNotContain(_, feature))
node.rightNode.foreach(assertTreeDoesNotContain(_, feature))
}
model.trees.foreach{tree =>
assertTreeDoesNotContain(tree.topNode, 0)
}
}

/*
test("subsampling rate in RandomForest"){
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20)
val rdd = sc.parallelize(arr)
@@ -233,5 +258,5 @@ }
}
}
}

*/
}