35 changes: 14 additions & 21 deletions mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala
@@ -17,16 +17,14 @@

 package org.apache.spark.ml.tree.impl

-import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark.Logging
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.TreeUtil._
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impurity.{Variance, Gini, Entropy, Impurity}
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
 import org.apache.spark.mllib.tree.model.ImpurityStats
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -89,8 +87,8 @@ private[ml] object AltDT extends Logging {
   }

   private[impl] object AltDTMetadata {
-    def fromStrategy(strategy: Strategy) = new AltDTMetadata(strategy.numClasses, strategy.maxBins,
-      strategy.minInfoGain, strategy.impurity)
+    def fromStrategy(strategy: Strategy): AltDTMetadata = new AltDTMetadata(strategy.numClasses,
+      strategy.maxBins, strategy.minInfoGain, strategy.impurity)
   }

   /**
@@ -144,15 +142,14 @@ private[ml] object AltDT extends Logging {
     }
     // Group columns together into one array of columns per partition.
     // TODO: Test avoiding this grouping, and see if it matters.
-    val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions { iterator =>
-      val groupedCols = new ArrayBuffer[FeatureVector]
-      iterator.foreach(groupedCols += _)
-      if (groupedCols.nonEmpty) Iterator(groupedCols.toArray) else Iterator()
+    val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions {
+      iterator: Iterator[FeatureVector] =>
+        if (iterator.nonEmpty) Iterator(iterator.toArray) else Iterator()
     }
     groupedColStore.persist(StorageLevel.MEMORY_AND_DISK)

     // Initialize partitions with 1 node (each instance at the root node).
-    var partitionInfosA: RDD[PartitionInfo] = groupedColStore.map { groupedCols =>
+    var partitionInfos: RDD[PartitionInfo] = groupedColStore.map { groupedCols =>
       val initActive = new BitSet(1)
       initActive.set(0)
       new PartitionInfo(groupedCols, Array[Int](0, numRows), initActive)
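Reviewer note on the `mapPartitions` hunk above: since `nonEmpty` only probes `hasNext` without consuming the iterator, `iterator.toArray` can replace the manual `ArrayBuffer` accumulation the old code used, which is also why the `ArrayBuffer` import could be dropped. A minimal sketch of the pattern in isolation (the `groupByPartition` helper and object name are hypothetical, not part of this patch):

```scala
import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD

object GroupByPartitionSketch {
  // Collapse each partition into one array; empty partitions emit nothing,
  // so downstream code never sees zero-length column groups.
  def groupByPartition[T: ClassTag](rdd: RDD[T]): RDD[Array[T]] =
    rdd.mapPartitions { iterator =>
      if (iterator.nonEmpty) Iterator(iterator.toArray) else Iterator.empty
    }
}
```

Behavior matches the old `ArrayBuffer` version but avoids one intermediate copy per partition.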
@@ -165,16 +162,10 @@
     var activeNodePeriphery: Array[LearningNode] = Array(rootNode)
     var numNodeOffsets: Int = 2

-    val partitionInfosDebug = new scala.collection.mutable.ArrayBuffer[RDD[PartitionInfo]]()
-    partitionInfosDebug.append(partitionInfosA)
-
     // Iteratively learn, one level of the tree at a time.
     var currentLevel = 0
     var doneLearning = false
     while (currentLevel < strategy.maxDepth && !doneLearning) {
-
-      val partitionInfos = partitionInfosDebug.last
-
       // Compute best split for each active node.
       val bestSplitsAndGains: Array[(Option[Split], ImpurityStats)] =
         computeBestSplits(partitionInfos, labelsBc, metadata)
@@ -208,12 +199,13 @@

       // Broadcast aggregated bit vectors. On each partition, update instance--node map.
       val aggBitVectorsBc = input.sparkContext.broadcast(aggBitVectors)
-      // partitionInfos = partitionInfos.map { partitionInfo =>
-      val partitionInfosB = partitionInfos.map { partitionInfo =>
+      val newPartitionInfos = partitionInfos.map { partitionInfo =>
         partitionInfo.update(aggBitVectorsBc.value, numNodeOffsets)
       }
-      partitionInfosB.cache().count() // TODO: remove. For some reason, this is needed to make things work. Probably messing up somewhere above...
-      partitionInfosDebug.append(partitionInfosB)
+      // TODO: remove. For some reason, this is needed to make things work.
+      // Probably messing up somewhere above...
+      newPartitionInfos.cache().count()
+      partitionInfos = newPartitionInfos

       // TODO: unpersist aggBitVectorsBc after action.
     }
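Reviewer note on the two hunks above: together they drop the `partitionInfosDebug` buffer, which pinned every level's RDD, in favor of a single `var partitionInfos` reassigned once per level. The `cache().count()` the TODO flags forces each level to materialize before the next level builds on it; without an action, Spark's lazy evaluation would replay the full lineage on later iterations, and each level's broadcast bit vectors would have to remain valid. A sketch of that materialize-then-reassign pattern under those assumptions (the toy `run` method and its `+ 1` per-level work are hypothetical, not from this patch):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

object MaterializePerLevelSketch {
  // Evolve an RDD across iterations, forcing each step with an action so
  // the next iteration starts from materialized data rather than replaying
  // the full lineage from the original input.
  def run(sc: SparkContext, levels: Int): RDD[Long] = {
    var state: RDD[Long] = sc.parallelize(0L until 1000L)
    for (_ <- 0 until levels) {
      val next = state.map(_ + 1).persist(StorageLevel.MEMORY_AND_DISK)
      next.count()       // action: materializes `next` eagerly
      state.unpersist()  // no-op for the initial, never-persisted RDD
      state = next
    }
    state
  }
}
```

Per the TODO, the extra action in the patch may be masking an upstream bug rather than being deliberate design, so treat this sketch as the general pattern only.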
@@ -260,7 +252,8 @@ private[ml] object AltDT extends Logging {
     //   for each active node, best split + info gain,
     //   where the best split is None if no useful split exists
     val partBestSplitsAndGains: RDD[Array[(Option[Split], ImpurityStats)]] = partitionInfos.map {
-      case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], activeNodes: BitSet) =>
+      case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int],
+        activeNodes: BitSet) =>
         val localLabels = labelsBc.value
         // Iterate over the active nodes in the current level.
         activeNodes.iterator.map { nodeIndexInLevel: Int =>
@@ -160,7 +160,7 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
   def prob(label: Double): Double = -1

   /** Get [[Predict]] struct. */
-  def getPredict = {
+  def getPredict: Predict = {
     val pred = this.predict
     new Predict(predict = pred, prob = this.prob(pred))
   }
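Reviewer note: the `getPredict` change here and the `fromStrategy` change in AltDT.scala apply the same Scala style rule, that public and protected members should declare their return types explicitly, because an inferred signature silently tracks whatever the body happens to return. A hypothetical illustration, not from the patch:

```scala
object ReturnTypeSketch {
  // Inferred: the public signature is whatever the body happens to produce.
  // Rewriting the body with List(...) would silently change it to List[Int].
  def inferred = Seq(1, 2, 3)

  // Explicit: the contract stays Seq[Int] however the body evolves.
  def explicit: Seq[Int] = List(1, 2, 3)
}
```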