6 changes: 2 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -275,10 +275,8 @@ private[tree] class LearningNode(
// Here we want to keep the same behavior as the old mllib.DecisionTreeModel
new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
}

}
}

}

private[tree] object LearningNode {
@@ -292,8 +290,8 @@ private[tree] object LearningNode {
}

/** Create an empty node with the given node index. Values must be set later on. */
def emptyNode(nodeIndex: Int): LearningNode = {
new LearningNode(nodeIndex, None, None, None, false, null)
def emptyNode(id: Int): LearningNode = {
new LearningNode(id, None, None, None, false, null)
}

// The below indexing methods were copied from spark.mllib.tree.model.Node
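Note: `emptyNode` now takes `id` rather than `nodeIndex`, matching the indexing methods mentioned above. A minimal sketch of the 1-based heap indexing those methods implement (assuming the same convention as `spark.mllib.tree.model.Node`; the object name here is illustrative):

```scala
object NodeIndexingSketch {
  def leftChildIndex(id: Int): Int = id << 1         // 2 * id
  def rightChildIndex(id: Int): Int = (id << 1) + 1  // 2 * id + 1
  def parentIndex(id: Int): Int = id >> 1            // integer division by 2
  def indexToLevel(id: Int): Int = {
    require(id > 0, "node ids are 1-based; the root has id 1")
    31 - java.lang.Integer.numberOfLeadingZeros(id)  // floor(log2(id))
  }
}
```

Under this scheme the root created by `LearningNode.emptyNode(id = 1)` gets children 2 and 3, which in turn get children 4 through 7, and so on.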
16 changes: 7 additions & 9 deletions mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala
@@ -137,8 +137,6 @@ private[ml] object AltDT extends Logging {
// rather than 1 copy per worker. This means a lot of random accesses.
// We could improve this by applying first-level sorting (by node) to labels.

// TODO: RIGHT HERE NOW: JUST ADDED ISUNORDERED

// Sort each column by feature values.
val colStore: RDD[FeatureVector] = colStoreInit.map { case (featureIndex: Int, col: Vector) =>
val featureArity: Int = strategy.categoricalFeaturesInfo.getOrElse(featureIndex, 0)
@@ -293,9 +291,11 @@ private[ml] object AltDT extends Logging {
* On driver: Grow tree based on chosen splits, and compute new set of active nodes.
* @param oldPeriphery Old periphery of active nodes.
* @param bestSplitsAndGains Best (split, gain) pairs, which can be zipped with the old
* periphery.
* periphery. These stats will be used to replace the stats in
* any nodes which are split.
* @param minInfoGain Threshold for min info gain required to split a node.
* @return New active node periphery
* @return New active node periphery.
* If a node is split, then this method will update its fields.
*/
private[impl] def computeActiveNodePeriphery(
oldPeriphery: Array[LearningNode],
@@ -482,12 +482,13 @@ private[ml] object AltDT extends Logging {

var bestSplitIndex: Int = -1 // index into categoriesSortedByCentroid
val bestLeftImpurityAgg = leftImpurityAgg.deepCopy()
var bestGain: Double = -1.0
var bestGain: Double = 0.0
val fullImpurity = rightImpurityAgg.getCalculator.calculate()
var leftCount: Double = 0.0
var rightCount: Double = rightImpurityAgg.getCount
val fullCount: Double = rightCount

// Consider all splits. These only cover valid splits, with at least one category on each side.
val numSplits = categoriesSortedByCentroid.length - 1
var sortedCatIndex = 0
while (sortedCatIndex < numSplits) {
@@ -512,9 +513,6 @@
sortedCatIndex += 1
}

assert(bestSplitIndex != -1, "Unknown error in AltDT split selection for ordered categorical" +
s" variable with numSplits = $numSplits.")

val categoriesForSplit =
categoriesSortedByCentroid.slice(0, bestSplitIndex + 1).map(_.toDouble)
val bestFeatureSplit =
@@ -524,7 +522,7 @@
val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, fullImpurityAgg.getCalculator,
bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator)

if (bestSplitIndex == 0 || bestSplitIndex == categoriesSortedByCentroid.length - 1) {
if (bestSplitIndex == -1 || bestGain == 0.0) {
(None, bestImpurityStats)
} else {
(Some(bestFeatureSplit), bestImpurityStats)
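Note: with `bestGain` initialized to 0.0 instead of -1.0 and the assertion dropped, a feature whose best split yields no gain now returns `None` rather than failing. A minimal sketch of the centroid-sorted scan for ordered categorical features (hypothetical names; `gainForPrefix` stands in for the impurity-aggregator arithmetic):

```scala
// Categories are pre-sorted by label centroid; every proper prefix of that
// ordering is a candidate left side, so there are (numCategories - 1)
// candidate splits, each with at least one category on each side.
def bestPrefixSplit(
    categoriesSortedByCentroid: Array[Int],
    gainForPrefix: Int => Double): (Option[Array[Int]], Double) = {
  var bestSplitIndex = -1  // index into categoriesSortedByCentroid
  var bestGain = 0.0       // 0.0 (not -1.0): zero-gain splits must lose
  val numSplits = categoriesSortedByCentroid.length - 1
  var i = 0
  while (i < numSplits) {
    val gain = gainForPrefix(i)
    if (gain > bestGain) {
      bestGain = gain
      bestSplitIndex = i
    }
    i += 1
  }
  if (bestSplitIndex == -1 || bestGain == 0.0) {
    (None, bestGain) // no useful split; the node stays a leaf
  } else {
    (Some(categoriesSortedByCentroid.slice(0, bestSplitIndex + 1)), bestGain)
  }
}
```

Because `bestGain` starts at 0.0, a node whose labels are already pure can never select a split, which is exactly what the new "return bad split" tests below rely on.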
@@ -143,7 +143,7 @@ private[ml] object RandomForest extends Logging {
rng.setSeed(seed)

// Allocate and queue root nodes.
val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(nodeIndex = 1))
val topNodes = Array.fill[LearningNode](numTrees)(LearningNode.emptyNode(id = 1))
Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))

while (nodeQueue.nonEmpty) {
@@ -166,4 +166,8 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal

override def toString: String = s"EntropyCalculator(stats = [${stats.mkString(", ")}])"

private[spark] def exactlyEquals(other: ImpurityCalculator): Boolean = other match {
case o: EntropyCalculator => stats.sameElements(o.stats)
case _ => false
}
}
@@ -162,4 +162,8 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul

override def toString: String = s"GiniCalculator(stats = [${stats.mkString(", ")}])"

private[spark] def exactlyEquals(other: ImpurityCalculator): Boolean = other match {
case o: GiniCalculator => stats.sameElements(o.stats)
case _ => false
}
}
@@ -185,4 +185,6 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
result._1
}

/** Test exact equality */
private[spark] def exactlyEquals(other: ImpurityCalculator): Boolean
}
@@ -136,4 +136,8 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
s"VarianceAggregator(cnt = ${stats(0)}, sum = ${stats(1)}, sum2 = ${stats(2)})"
}

private[spark] def exactlyEquals(other: ImpurityCalculator): Boolean = other match {
case o: VarianceCalculator => stats.sameElements(o.stats)
case _ => false
}
}
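Note: the three `exactlyEquals` overrides above share one shape: a runtime type check plus element-wise comparison of the raw `stats` arrays. The comparison is exact floating-point equality, so it suits test assertions rather than comparisons of independently computed aggregates. A quick illustration (assuming spark-package visibility, as in the test suite below):

```scala
val a = new GiniCalculator(Array(3.0, 4.0))
val b = new GiniCalculator(Array(3.0, 4.0))
val c = new EntropyCalculator(Array(3.0, 4.0))
assert(a.exactlyEquals(b))   // same subclass, identical stats
assert(!a.exactlyEquals(c))  // identical stats but a different impurity type
```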
@@ -88,6 +88,7 @@ private[spark] object InformationGainStats {
* @param rightImpurityCalculator impurity statistics for right child node
* @param valid whether the current split satisfies minimum info gain or
* minimum number of instances per node
* TODO: Can we remove this? Not sure if this is used anywhere...
*/
@DeveloperApi
private[spark] class ImpurityStats(
@@ -114,6 +115,15 @@ private[spark] class ImpurityStats(
} else {
-1.0
}

/** Test exact equality */
private[spark] def exactlyEquals(other: ImpurityStats): Boolean = {
gain == other.gain && impurity == other.impurity &&
impurityCalculator.exactlyEquals(other.impurityCalculator) &&
leftImpurityCalculator.exactlyEquals(other.leftImpurityCalculator) &&
rightImpurityCalculator.exactlyEquals(other.rightImpurityCalculator) &&
valid == other.valid
}
}

private[spark] object ImpurityStats {
162 changes: 151 additions & 11 deletions mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala
@@ -19,12 +19,13 @@ package org.apache.spark.ml.tree.impl

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.regression.DecisionTreeRegressor
import org.apache.spark.ml.tree.{LeafNode, InternalNode, ContinuousSplit}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.AltDT.{AltDTMetadata, FeatureVector, PartitionInfo}
import org.apache.spark.ml.tree.impl.TreeUtil._
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.{Variance, Gini, Entropy, Impurity}
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.collection.BitSet

@@ -44,7 +45,6 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
.setMaxDepth(10)
.setAlgorithm("byCol")
val model = dt.fit(df)
println(model.toDebugString) // TODO: remove println
assert(model.rootNode.isInstanceOf[InternalNode])
val root = model.rootNode.asInstanceOf[InternalNode]
assert(root.leftChild.isInstanceOf[InternalNode] && root.rightChild.isInstanceOf[LeafNode])
@@ -61,7 +61,6 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
.setMaxDepth(10)
.setAlgorithm("byCol")
val model = dt.fit(df)
println(model.toDebugString) // TODO: remove println
assert(model.rootNode.isInstanceOf[InternalNode])
val root = model.rootNode.asInstanceOf[InternalNode]
assert(root.leftChild.isInstanceOf[InternalNode] && root.rightChild.isInstanceOf[InternalNode])
@@ -147,18 +146,94 @@
//////////////////////////////// Choosing splits //////////////////////////////////

test("computeBestSplits") {
// TODO
}

test("chooseSplit") {
test("chooseSplit: choose correct type of split") {
val labels = Seq(0, 0, 0, 1, 1, 1, 1).map(_.toDouble).toArray
val fromOffset = 1
val toOffset = 4
val impurity = Entropy
val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)

val col1 = FeatureVector.fromOriginal(featureIndex = 0, featureArity = 0,
featureVector = Vectors.dense(0.8, 0.1, 0.1, 0.2, 0.3, 0.5, 0.6))
val (split1, _) = AltDT.chooseSplit(col1, labels, fromOffset, toOffset, metadata)
assert(split1.nonEmpty && split1.get.isInstanceOf[ContinuousSplit])

val col2 = FeatureVector.fromOriginal(featureIndex = 0, featureArity = 3,
featureVector = Vectors.dense(0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0))
val (split2, _) = AltDT.chooseSplit(col2, labels, fromOffset, toOffset, metadata)
assert(split2.nonEmpty && split2.get.isInstanceOf[CategoricalSplit])
}
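Note: this test leans on the arity convention used when the columns are built (`strategy.categoricalFeaturesInfo.getOrElse(featureIndex, 0)` in the first hunk): arity 0 marks a continuous feature, any positive arity a categorical one. A schematic of the dispatch `chooseSplit` is expected to perform (hypothetical shape, not the actual signature):

```scala
// Which chooser handles a column, by feature type.
def chooserFor(featureArity: Int, isUnordered: Boolean): String =
  if (featureArity == 0) "chooseContinuousSplit"
  else if (isUnordered) "chooseUnorderedCategoricalSplit"
  else "chooseOrderedCategoricalSplit"
```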

test("chooseOrderedCategoricalSplit: basic case") {
val featureIndex = 0
val values = Seq(0, 0, 1, 2, 2, 2, 2).map(_.toDouble)
val featureArity = values.max.toInt + 1

def testHelper(
labels: Seq[Double],
expectedLeftCategories: Array[Double],
expectedLeftStats: Array[Double],
expectedRightStats: Array[Double]): Unit = {
val expectedRightCategories = Range(0, featureArity)
.filter(c => !expectedLeftCategories.contains(c)).map(_.toDouble).toArray
val impurity = Entropy
val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
val (split, stats) =
AltDT.chooseOrderedCategoricalSplit(featureIndex, values, labels, metadata, featureArity)
split match {
case Some(s: CategoricalSplit) =>
assert(s.featureIndex === featureIndex)
assert(s.leftCategories === expectedLeftCategories)
assert(s.rightCategories === expectedRightCategories)
case _ =>
throw new AssertionError(
s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}")
}
val fullImpurityStatsArray =
Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble)
val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
assert(stats.gain === fullImpurity)
assert(stats.impurity === fullImpurity)
assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
assert(stats.leftImpurityCalculator.stats === expectedLeftStats)
assert(stats.rightImpurityCalculator.stats === expectedRightStats)
assert(stats.valid)
}

val labels1 = Seq(0, 0, 1, 1, 1, 1, 1).map(_.toDouble)
testHelper(labels1, Array(0.0), Array(2.0, 0.0), Array(0.0, 5.0))

val labels2 = Seq(0, 0, 0, 1, 1, 1, 1).map(_.toDouble)
testHelper(labels2, Array(0.0, 1.0), Array(3.0, 0.0), Array(0.0, 4.0))
}
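Note: in both helper calls the chosen split separates the classes perfectly, so the weighted child impurity is zero and the gain equals the parent impurity; hence `stats.gain === fullImpurity`. A worked check for `labels2` (assuming the base-2 entropy that mllib's `Entropy` computes):

```scala
def entropy(counts: Array[Double]): Double = {
  val n = counts.sum
  counts.filter(_ > 0.0).map { c =>
    val p = c / n
    -p * (math.log(p) / math.log(2)) // log base 2
  }.sum
}
val parent = entropy(Array(3.0, 4.0)) // class counts for labels2, ~0.985
val weightedChildren =                // split {0, 1} vs {2}: both children pure
  (3.0 / 7) * entropy(Array(3.0, 0.0)) + (4.0 / 7) * entropy(Array(0.0, 4.0))
assert(weightedChildren == 0.0 && (parent - weightedChildren) == parent)
```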

test("chooseOrderedCategoricalSplit: return bad split if best split is on end") {
test("chooseOrderedCategoricalSplit: return bad split if we should not split") {
val featureIndex = 0
val values = Seq(0, 0, 1, 2, 2, 2, 2).map(_.toDouble)
val featureArity = values.max.toInt + 1

val labels = Seq(1, 1, 1, 1, 1, 1, 1).map(_.toDouble)

val impurity = Entropy
val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
val (split, stats) =
AltDT.chooseOrderedCategoricalSplit(featureIndex, values, labels, metadata, featureArity)
assert(split.isEmpty)
val fullImpurityStatsArray =
Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble)
val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
assert(stats.gain === 0.0)
assert(stats.impurity === fullImpurity)
assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
assert(stats.valid)
}

// test("chooseUnorderedCategoricalSplit") { }
// test("chooseUnorderedCategoricalSplit: basic case") { }

// test("chooseUnorderedCategoricalSplit: return bad split if we should not split") { }

test("chooseContinuousSplit: basic case") {
val featureIndex = 0
@@ -175,7 +250,8 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
throw new AssertionError(
s"Expected ContinuousSplit but got ${split.getClass.getSimpleName}")
}
val fullImpurityStatsArray = Array(2.0, 3.0)
val fullImpurityStatsArray =
Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble)
val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
assert(stats.gain === fullImpurity)
assert(stats.impurity === fullImpurity)
@@ -185,8 +261,23 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(stats.valid)
}

// TODO: Add this test once we make this change.
// test("chooseContinuousSplit: return bad split if best split is on end") { }
test("chooseContinuousSplit: return bad split if we should not split") {
val featureIndex = 0
val values = Seq(0.1, 0.2, 0.3, 0.4, 0.5)
val labels = Seq(0.0, 0.0, 0.0, 0.0, 0.0)
val impurity = Entropy
val metadata = new AltDTMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity)
val (split, stats) = AltDT.chooseContinuousSplit(featureIndex, values, labels, metadata)
// split should be None
assert(split.isEmpty)
// stats for parent node should be correct
val fullImpurityStatsArray =
Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble)
val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length)
assert(stats.gain === 0.0)
assert(stats.impurity === fullImpurity)
assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
}
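Note: a minimal sketch of the sweep `chooseContinuousSplit` performs over the (already sorted) values, with hypothetical names; `gainFor` stands in for the impurity arithmetic and must read the arrays immediately, since they are reused. Each boundary between distinct consecutive values is a candidate threshold, and with `bestGain` starting at 0.0 a uniformly labeled node never splits, as this test asserts:

```scala
def scanContinuous(
    values: Seq[Double],
    labels: Seq[Double],  // binary labels, 0.0 or 1.0
    gainFor: (Array[Double], Array[Double]) => Double): (Option[Double], Double) = {
  val left = Array(0.0, 0.0)
  val right = Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble)
  var bestGain = 0.0
  var bestThreshold = Double.NegativeInfinity
  // Move one instance at a time from right to left; the last instance is
  // skipped because it would leave the right side empty.
  values.zip(labels).zip(values.tail).foreach { case ((v, label), next) =>
    left(label.toInt) += 1.0
    right(label.toInt) -= 1.0
    if (v != next) { // candidate threshold only between distinct values
      val gain = gainFor(left, right)
      if (gain > bestGain) { bestGain = gain; bestThreshold = v }
    }
  }
  if (bestGain == 0.0) (None, 0.0) else (Some(bestThreshold), bestGain)
}
```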

//////////////////////////////// Bit subvectors //////////////////////////////////

@@ -258,6 +349,55 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
//////////////////////////////// Active nodes //////////////////////////////////

test("computeActiveNodePeriphery") {
// old periphery: 2 nodes
val left = LearningNode.emptyNode(id = 1)
val right = LearningNode.emptyNode(id = 2)
val oldPeriphery: Array[LearningNode] = Array(left, right)
// bestSplitsAndGains: Do not split left, but split right node.
val lCalc = new EntropyCalculator(Array(8.0, 1.0))
val lStats = new ImpurityStats(0.0, lCalc.calculate(),
lCalc, lCalc, new EntropyCalculator(Array(0.0, 0.0)))

val rSplit = new ContinuousSplit(featureIndex = 1, threshold = 0.6)
val rCalc = new EntropyCalculator(Array(5.0, 7.0))
val rRightChildCalc = new EntropyCalculator(Array(1.0, 5.0))
val rLeftChildCalc = new EntropyCalculator(Array(
rCalc.stats(0) - rRightChildCalc.stats(0),
rCalc.stats(1) - rRightChildCalc.stats(1)))
val rGain = {
val rightWeight = rRightChildCalc.stats.sum / rCalc.stats.sum
val leftWeight = rLeftChildCalc.stats.sum / rCalc.stats.sum
rCalc.calculate() -
rightWeight * rRightChildCalc.calculate() - leftWeight * rLeftChildCalc.calculate()
}
val rStats =
new ImpurityStats(rGain, rCalc.calculate(), rCalc, rLeftChildCalc, rRightChildCalc)

val bestSplitsAndGains: Array[(Option[Split], ImpurityStats)] =
Array((None, lStats), (Some(rSplit), rStats))

// Test A: Split right node
val newPeriphery1: Array[LearningNode] =
AltDT.computeActiveNodePeriphery(oldPeriphery, bestSplitsAndGains, minInfoGain = 0.0)
// Expect 2 active nodes
assert(newPeriphery1.length === 2)
// Confirm right node was updated
assert(right.split.get === rSplit)
assert(!right.isLeaf)
assert(right.stats.exactlyEquals(rStats))
assert(right.leftChild.nonEmpty && right.leftChild.get === newPeriphery1(0))
assert(right.rightChild.nonEmpty && right.rightChild.get === newPeriphery1(1))
// Confirm new active nodes have stats but no children
assert(newPeriphery1(0).leftChild.isEmpty && newPeriphery1(0).rightChild.isEmpty &&
newPeriphery1(0).split.isEmpty &&
newPeriphery1(0).stats.impurityCalculator.exactlyEquals(rLeftChildCalc))
assert(newPeriphery1(1).leftChild.isEmpty && newPeriphery1(1).rightChild.isEmpty &&
newPeriphery1(1).split.isEmpty &&
newPeriphery1(1).stats.impurityCalculator.exactlyEquals(rRightChildCalc))

// Test B: Increase minInfoGain, so split nothing
val newPeriphery2: Array[LearningNode] =
AltDT.computeActiveNodePeriphery(oldPeriphery, bestSplitsAndGains, minInfoGain = 1000.0)
assert(newPeriphery2.isEmpty)
}

}
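Note: a condensed sketch of the contract this test checks, per the updated doc on `computeActiveNodePeriphery`: nodes whose best gain clears `minInfoGain` are updated in place and only their fresh children enter the new periphery. Field names and the `LearningNode.leftChildIndex` / `rightChildIndex` helpers are assumptions based on the indexing methods referenced in `Node.scala`, and the seeding of child stats from `leftImpurityCalculator` / `rightImpurityCalculator` is elided:

```scala
def peripherySketch(
    oldPeriphery: Array[LearningNode],
    bestSplitsAndGains: Array[(Option[Split], ImpurityStats)],
    minInfoGain: Double): Array[LearningNode] = {
  oldPeriphery.zip(bestSplitsAndGains).flatMap {
    case (node, (Some(split), stats)) if stats.gain > minInfoGain =>
      node.split = Some(split)  // assumed mutable fields
      node.isLeaf = false
      node.stats = stats
      node.leftChild = Some(LearningNode.emptyNode(LearningNode.leftChildIndex(node.id)))
      node.rightChild = Some(LearningNode.emptyNode(LearningNode.rightChildIndex(node.id)))
      Seq(node.leftChild.get, node.rightChild.get) // the new active nodes
    case (node, (_, stats)) =>
      node.stats = stats // stays a leaf; contributes nothing (Test B)
      Seq.empty
  }
}
```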