Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4f4bf8a
Wrote rowToColumnStoreDense method
jkbradley Nov 7, 2014
b8d2c93
Wrote TreeUtilSuite to test rowToColumnStoreDense
jkbradley Nov 7, 2014
100d506
scala style
jkbradley Nov 7, 2014
74d6732
Added rowToColumnStoreSparse to tree Utils, and added test suite for it.
jkbradley Nov 30, 2014
1d1d71f
small fix in TreeUtilSuite
jkbradley Dec 23, 2014
3fcdeaf
some updates before rebasing
jkbradley Jul 23, 2015
bfa1819
Done implement MVP for partitioning by feature, but need to test and …
jkbradley Jul 23, 2015
9e7b0e9
debugging AltDT, not done yet
jkbradley Jul 24, 2015
be5b237
debugging and adding unit tests
jkbradley Jul 24, 2015
3c1a4d3
AltDT failing on imbalanced tree test. Going to refactor to broadcas…
jkbradley Jul 26, 2015
15b009b
AltDT is working
jkbradley Jul 27, 2015
d58ce40
removed debug printlns
jkbradley Jul 27, 2015
40719af
fix merge conflicts
jkbradley Aug 26, 2015
5941577
modified AltDT to use ImpurityStats
jkbradley Aug 26, 2015
248d0a7
fixed conflicts with master, and started to add support for categoric…
jkbradley Aug 26, 2015
0764b90
Added computation of unordered features
jkbradley Aug 27, 2015
ceaf5f5
refactored to prepare for adding categorical feature support
jkbradley Aug 27, 2015
675951e
Added support for ordered categorical features, but not tested yet
jkbradley Sep 16, 2015
1855176
Added a few unit tests
jkbradley Sep 24, 2015
d897714
Changes BitSubvector to use System.arraycopy
feynmanliang Sep 29, 2015
6cc1ed8
Removes commented code
feynmanliang Sep 29, 2015
71fd54c
Fixes nested tests
feynmanliang Sep 29, 2015
1597d11
Makes BitSubvectors left-align on word boundaries and adds comments
feynmanliang Oct 3, 2015
92a6fa5
Changes copyFrom to orWithOffset
feynmanliang Oct 3, 2015
506ac10
Adds dumb impl and failing tests for offset!=0
feynmanliang Oct 3, 2015
2c329df
Pass unit tests, adds debug code
feynmanliang Oct 3, 2015
9a642b8
Removes debug printlns
feynmanliang Oct 3, 2015
49a628a
Fixes implementation, cleans comments, more tests
feynmanliang Oct 3, 2015
c5e5480
Improves docs and fixes Long.MaxValue bug
feynmanliang Oct 10, 2015
1213f86
Cleans up shared state in BitSetSuite
feynmanliang Oct 10, 2015
8f1360f
Cleans up shared state in BitSubvectorSuite
feynmanliang Oct 10, 2015
3c5060d
Loops through offsets for orWithOffset test
feynmanliang Oct 10, 2015
2db78ee
Updates test names
feynmanliang Oct 10, 2015
7072f91
Adds tests for 0 until offset
feynmanliang Nov 14, 2015
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,37 @@ class BitSet(numBits: Int) extends Serializable {

/** Return the number of longs it would take to hold numBits. */
private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1

/**
 * Bit-wise OR of `other`, left-shifted by `offset` bits, into this BitSet: the ith bit of `other`
 * is ORed against the (i + offset)th bit of this instance. For performance, the OR is computed
 * word-by-word rather than bit-by-bit.
 *
 * This function mutates the current BitSet instance (i.e. not `other`). Bits of `other` whose
 * shifted position falls past the last word of this instance are dropped.
 *
 * @param other the BitSet whose bits are ORed into this one; it is only read.
 * @param offset the amount to left-shift (with zero padding) `other` before performing the OR,
 *               must be >= 0.
 */
private[spark] def orWithOffset(other: BitSet, offset: Int): Unit = {
  require(offset >= 0, s"offset must be >= 0 but got $offset")

  val wordOffset = offset >> 6 // divide by 64
  val lowShift = offset & 0x3f // mod 64

  // Bit vectors have memory layout [63..0|127..64|...] where | denotes word boundaries, so the
  // highest-order bits of a source word spill into the lowest-order bits of the next
  // destination word.
  val carryShift = 64 - lowShift

  val numDestWords = words.length
  val numSrcWords = other.words.length

  var srcIndex = 0
  // Stop as soon as either the source is exhausted or the destination word would be out of
  // range, so the write index can never run past the end of `words`.
  while (srcIndex < numSrcWords && srcIndex + wordOffset < numDestWords) {
    val srcWord = other.words(srcIndex)
    val destIndex = srcIndex + wordOffset

    // Low part: the source word moved up by the in-word shift; bits shifted out are carried
    // into the next destination word below.
    words(destIndex) = words(destIndex) | (srcWord << lowShift)

    // Carry part: only exists for a non-word-aligned offset. The shift must be logical (>>>);
    // an arithmetic shift would sign-extend a source word whose bit 63 is set and OR spurious
    // 1s into the destination.
    if (lowShift > 0 && destIndex + 1 < numDestWords) {
      words(destIndex + 1) = words(destIndex + 1) | (srcWord >>> carryShift)
    }

    srcIndex += 1
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,29 @@ class BitSetSuite extends SparkFunSuite {
assert(bitset.cardinality() === setBits.size)
}

test("orWithOffset") {
val setBits = Seq(0, 9, 1, 10, 90, 96)
val bitset = new BitSet(100)
setBits.foreach(i => bitset.set(i))

for {
offset <- Seq(0, 1, 63, 64, 65)
} {
val copyBitset = new BitSet(100)
copyBitset.orWithOffset(bitset, offset)
for (i <- 0 until offset) {
assert(!copyBitset.get(i))
}
for (i <- offset until 100) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could check for elements 0 until offset being set to 0

if (setBits.contains(i - offset)) {
assert(copyBitset.get(i))
} else {
assert(!copyBitset.get(i))
}
}
}
}

test("100% full bit set") {
val bitset = new BitSet(10000)
for (i <- 0 until 10000) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ object DecisionTreeExample {
testInput: String = "",
dataFormat: String = "libsvm",
algo: String = "Classification",
algorithm: String = "byRow",
maxDepth: Int = 5,
maxBins: Int = 32,
minInstancesPerNode: Int = 1,
Expand All @@ -77,6 +78,9 @@ object DecisionTreeExample {
opt[String]("algo")
.text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
.action((x, c) => c.copy(algo = x))
// Learning strategy option; the help text must report the default of `algorithm`, not of `algo`.
opt[String]("algorithm")
  .text(s"algorithm (byRow, byCol), default: ${defaultParams.algorithm}")
  .action((x, c) => c.copy(algorithm = x))
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
.action((x, c) => c.copy(maxDepth = x))
Expand Down Expand Up @@ -236,33 +240,37 @@ object DecisionTreeExample {
}
// (2) Identify categorical features using VectorIndexer.
// Features with more than maxCategories values will be treated as continuous.
/*
val featuresIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(10)
stages += featuresIndexer
*/
// (3) Learn Decision Tree
val dt = algo match {
case "classification" =>
new DecisionTreeClassifier()
.setFeaturesCol("indexedFeatures")
.setFeaturesCol("features") // indexedFeatures
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
.setMinInstancesPerNode(params.minInstancesPerNode)
.setMinInfoGain(params.minInfoGain)
.setCacheNodeIds(params.cacheNodeIds)
.setCheckpointInterval(params.checkpointInterval)
.setAlgorithm(params.algorithm)
case "regression" =>
new DecisionTreeRegressor()
.setFeaturesCol("indexedFeatures")
.setFeaturesCol("features") // indexedFeatures
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
.setMinInstancesPerNode(params.minInstancesPerNode)
.setMinInfoGain(params.minInfoGain)
.setCacheNodeIds(params.cacheNodeIds)
.setCheckpointInterval(params.checkpointInterval)
.setAlgorithm(params.algorithm)
case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
}
stages += dt
Expand All @@ -278,14 +286,14 @@ object DecisionTreeExample {
algo match {
case "classification" =>
val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel]
if (treeModel.numNodes < 20) {
if (treeModel.numNodes < 200) {
println(treeModel.toDebugString) // Print full model.
} else {
println(treeModel) // Print model summary.
}
case "regression" =>
val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeRegressionModel]
if (treeModel.numNodes < 20) {
if (treeModel.numNodes < 200) {
println(treeModel.toDebugString) // Print full model.
} else {
println(treeModel) // Print model summary.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
package org.apache.spark.ml.classification

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.tree.impl.{AltDT, RandomForest}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
Expand Down Expand Up @@ -62,6 +62,25 @@ final class DecisionTreeClassifier(override val uid: String)

override def setImpurity(value: String): this.type = super.setImpurity(value)

/**
 * Param selecting the algorithm used for learning.
 * Supported: "byRow" or "byCol" (case sensitive).
 * (default = "byRow")
 * @group param
 */
val algorithm: Param[String] = {
  // Single source of truth for both the help text and the validator.
  val supported = DecisionTreeClassifier.supportedAlgorithms
  new Param[String](
    this,
    "algorithm",
    s"Algorithm used for learning. Supported options: ${supported.mkString(", ")}",
    (value: String) => supported.contains(value))
}

setDefault(algorithm -> "byRow")

/** @group setParam */
def setAlgorithm(value: String): this.type = set(algorithm, value)

/** @group getParam */
def getAlgorithm: String = $(algorithm)

override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
Expand All @@ -74,9 +93,15 @@ final class DecisionTreeClassifier(override val uid: String)
}
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeClassificationModel]
val model = getAlgorithm match {
case "byRow" =>
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1,
featureSubsetStrategy = "all", seed = 0L, parentUID = Some(uid))
trees.head
case "byCol" =>
AltDT.train(oldDataset, strategy, parentUID = Some(uid))
}
model.asInstanceOf[DecisionTreeClassificationModel]
}

/** (private[ml]) Create a Strategy instance to use with the old API. */
Expand All @@ -94,6 +119,8 @@ final class DecisionTreeClassifier(override val uid: String)
object DecisionTreeClassifier {
/** Accessor for supported impurities: entropy, gini */
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities

final val supportedAlgorithms: Array[String] = Array("byRow", "byCol")
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ package org.apache.spark.ml.regression

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.tree.impl.{AltDT, RandomForest}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
Expand Down Expand Up @@ -62,14 +62,39 @@ final class DecisionTreeRegressor(override val uid: String)

override def setImpurity(value: String): this.type = super.setImpurity(value)

/**
 * Param selecting the algorithm used for learning.
 * Supported: "byRow" or "byCol" (case sensitive).
 * (default = "byRow")
 * @group param
 */
val algorithm: Param[String] = {
  // Single source of truth for both the help text and the validator.
  val supported = DecisionTreeRegressor.supportedAlgorithms
  new Param[String](
    this,
    "algorithm",
    s"Algorithm used for learning. Supported options: ${supported.mkString(", ")}",
    (value: String) => supported.contains(value))
}

setDefault(algorithm -> "byRow")

/** @group setParam */
def setAlgorithm(value: String): this.type = set(algorithm, value)

/** @group getParam */
def getAlgorithm: String = $(algorithm)

override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy = getOldStrategy(categoricalFeatures)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeRegressionModel]
val model = getAlgorithm match {
case "byRow" =>
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1,
featureSubsetStrategy = "all", seed = 0L, parentUID = Some(uid))
trees.head
case "byCol" =>
AltDT.train(oldDataset, strategy, parentUID = Some(uid))
}
model.asInstanceOf[DecisionTreeRegressionModel]
}

/** (private[ml]) Create a Strategy instance to use with the old API. */
Expand All @@ -85,6 +110,8 @@ final class DecisionTreeRegressor(override val uid: String)
object DecisionTreeRegressor {
/** Accessor for supported impurities: variance */
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities

final val supportedAlgorithms: Array[String] = Array("byRow", "byCol")
}

/**
Expand Down
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ private[tree] object LearningNode {
id: Int,
isLeaf: Boolean,
stats: ImpurityStats): LearningNode = {
new LearningNode(id, None, None, None, false, stats)
new LearningNode(id, None, None, None, isLeaf, stats)
}

/** Create an empty node with the given node index. Values must be set later on. */
Expand Down
18 changes: 18 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ sealed trait Split extends Serializable {
*/
private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean

/**
 * Decide the direction for a raw feature value, as opposed to the binned overload of this
 * method that takes a bin index.
 * @param feature Feature value (original value, not binned)
 * @return true (split to left) or false (split to right)
 */
private[tree] def shouldGoLeft(feature: Double): Boolean

/** Convert to old Split format */
private[tree] def toOld: OldSplit
}
Expand Down Expand Up @@ -112,6 +118,14 @@ final class CategoricalSplit private[ml] (
}
}

/**
 * Direction for a raw (not binned) categorical feature value.
 * @param feature Feature value (original value, not binned)
 * @return true (split to left) or false (split to right)
 */
override private[tree] def shouldGoLeft(feature: Double): Boolean = {
  // When `isLeft` is set, membership in `categories` sends the value left; otherwise
  // membership sends it right. Both cases collapse to an equality test.
  val listed = categories.contains(feature)
  listed == isLeft
}

override def equals(o: Any): Boolean = {
o match {
case other: CategoricalSplit => featureIndex == other.featureIndex &&
Expand Down Expand Up @@ -172,6 +186,10 @@ final class ContinuousSplit private[ml] (override val featureIndex: Int, val thr
}
}

/**
 * Direction for a raw (not binned) continuous feature value: left when the value does not
 * exceed `threshold`.
 * @param feature Feature value (original value, not binned)
 * @return true (split to left) or false (split to right)
 */
override private[tree] def shouldGoLeft(feature: Double): Boolean = feature <= threshold

override def equals(o: Any): Boolean = {
o match {
case other: ContinuousSplit =>
Expand Down
Loading