@@ -34,14 +34,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
*/
class Strategy (
val algo: Algo,
val impurity: Impurity,
val maxDepth: Int,
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable {
class Strategy (
val algo: Algo,
val impurity: Impurity,
val maxDepth: Int,
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable {

var numBins: Int = Int.MinValue

}
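For orientation, a minimal sketch of how this configuration class would be constructed; Classification is assumed to be one of the Algo enumeration values, and the feature-arity map below is illustrative rather than taken from the PR:

import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.impurity.Gini

// Hypothetical setup: classification tree of depth 5 with Gini impurity;
// feature 3 is categorical with 4 values, the remaining parameters keep their defaults.
val strategy = new Strategy(
  algo = Classification,
  impurity = Gini,
  maxDepth = 5,
  categoricalFeaturesInfo = Map(3 -> 4))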
@@ -17,8 +17,6 @@

package org.apache.spark.mllib.tree.impurity

import java.lang.UnsupportedOperationException

/**
* Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
* binary classification.
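This hunk only drops the redundant java.lang import; the entropy computation itself sits below the fold. For reference, the binary entropy function the comment links to is the standard H(p) = -p log2(p) - (1 - p) log2(1 - p), and a direct implementation over label counts would look roughly like this reference sketch (not the file's actual body):

// Reference sketch of binary entropy over label counts; the real Entropy.calculate
// is outside this hunk and may differ in details such as edge-case handling.
def binaryEntropy(c0: Double, c1: Double): Double = {
  if (c0 == 0 || c1 == 0) {
    0
  } else {
    val total = c0 + c1
    val p = c1 / total
    -p * math.log(p) / math.log(2) - (1 - p) * math.log(1 - p) / math.log(2)
  }
}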
@@ -17,32 +17,30 @@

package org.apache.spark.mllib.tree.impurity

import java.lang.UnsupportedOperationException

/**
* Class for calculating the [[http://en.wikipedia
* .org/wiki/Decision_tree_learning#Gini_impurity]] during binary classification
* Class for calculating the
* [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
* during binary classification.
*/
object Gini extends Impurity {

/**
* gini coefficient calculation
* Gini coefficient calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return gini coefficient value
* @return Gini coefficient value
*/
def calculate(c0 : Double, c1 : Double): Double = {
override def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
} else {
val total = c0 + c1
val f0 = c0 / total
val f1 = c1 / total
1 - f0*f0 - f1*f1
1 - f0 * f0 - f1 * f1
}
}

def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Gini.calculate")

}
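A quick numerical check of the formula above (the example numbers are illustrative, not from the PR):

// 40 instances of label 0 and 60 of label 1: f0 = 0.4, f1 = 0.6,
// so the impurity is 1 - 0.16 - 0.36 = 0.48.
val mixed = Gini.calculate(40.0, 60.0)   // 0.48
// A pure node takes the early-exit branch and scores 0.
val pure = Gini.calculate(100.0, 0.0)    // 0.0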
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.tree.impurity

/**
* Trail for calculating information gain
* Trait for calculating information gain.
*/
trait Impurity extends Serializable {

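The trait body is collapsed here, but the two calculate overloads that Gini and Variance override make its shape clear. A hypothetical implementation would look roughly like the sketch below (the MinError object is purely illustrative and not part of the PR; the signatures are inferred from the overrides elsewhere in this diff):

// Illustrative misclassification-error impurity implementing the same two overloads.
object MinError extends Impurity {
  // Binary classification: fraction of instances in the minority class.
  override def calculate(c0: Double, c1: Double): Double =
    if (c0 + c1 == 0) 0 else math.min(c0, c1) / (c0 + c1)

  // The regression overload is not meaningful for this impurity.
  override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
    throw new UnsupportedOperationException("MinError.calculate")
}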
@@ -17,25 +17,21 @@

package org.apache.spark.mllib.tree.impurity

import java.lang.UnsupportedOperationException

/**
* Class for calculating variance during regression
*/
object Variance extends Impurity {
def calculate(c0: Double, c1: Double): Double
= throw new UnsupportedOperationException("Variance.calculate")
override def calculate(c0: Double, c1: Double): Double =
throw new UnsupportedOperationException("Variance.calculate")

/**
* variance calculation
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
* @return
*/
def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
val squaredLoss = sumSquares - (sum*sum)/count
squaredLoss/count
override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
val squaredLoss = sumSquares - (sum * sum) / count
squaredLoss / count
}

}
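A quick numerical check of the formula (the numbers are illustrative): for labels 1, 2, 3, 4 the mean is 2.5 and the population variance is 1.25, which matches the code above.

// count = 4, sum = 1 + 2 + 3 + 4 = 10, sumSquares = 1 + 4 + 9 + 16 = 30
// squaredLoss = 30 - 10 * 10 / 4 = 5, so the variance is 5 / 4 = 1.25.
val v = Variance.calculate(4.0, 10.0, 30.0)   // 1.25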
@@ -30,6 +30,4 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin
*/
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) {

}
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
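A rough illustration of how a bin for a continuous feature might be assembled from two splits; Continuous is assumed to be one of the FeatureType values imported above, the thresholds are made up, and the category field is presumably only meaningful for categorical bins:

// Hypothetical bin covering feature 0 between thresholds 0.5 and 1.5.
val bin = Bin(
  lowSplit = Split(0, 0.5, Continuous, List()),
  highSplit = Split(0, 1.5, Continuous, List()),
  featureType = Continuous,
  category = Double.MinValue)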
@@ -46,6 +46,4 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
def predict(features: RDD[Array[Double]]): RDD[Double] = {
features.map(x => predict(x))
}


}
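The overload shown above just maps the single-point predict over the RDD, so usage would look roughly like this (model, sc, and the feature values are placeholders, not code from the PR):

// One feature vector at a time.
val label: Double = model.predict(Array(1.0, 0.0, 3.5))

// A whole dataset at once: RDD[Array[Double]] => RDD[Double].
val labels: RDD[Double] = model.predict(sc.parallelize(Seq(
  Array(1.0, 0.0, 3.5),
  Array(2.0, 1.0, 0.5))))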
@@ -36,6 +36,4 @@ class InformationGainStats(
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
.format(gain, impurity, leftImpurity, rightImpurity, predict)
}


}
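For readers new to the fields printed above: in the standard decision-tree formulation (not spelled out in this hunk), the stored gain would presumably be the parent impurity minus the size-weighted impurities of the two children:

gain = impurity - (n_left / n) * leftImpurity - (n_right / n) * rightImpurity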
@@ -37,7 +37,7 @@ class Node (
val split: Option[Split],
var leftNode: Option[Node],
var rightNode: Option[Node],
val stats: Option[InformationGainStats]) extends Serializable with Logging{
val stats: Option[InformationGainStats]) extends Serializable with Logging {

override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
"split = " + split + ", stats = " + stats
@@ -46,7 +46,7 @@ class Node (
* build the left node and right nodes if not leaf
* @param nodes array of nodes
*/
def build(nodes : Array[Node]): Unit = {
def build(nodes: Array[Node]): Unit = {

logDebug("building node " + id + " at level " +
(scala.math.log(id + 1)/scala.math.log(2)).toInt )
@@ -68,7 +68,7 @@
* @param feature feature value
* @return predicted value
*/
def predictIfLeaf(feature : Array[Double]) : Double = {
def predictIfLeaf(feature: Array[Double]) : Double = {
if (isLeaf) {
predict
} else{
@@ -87,5 +87,4 @@
}
}
}

}
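The non-leaf branch of predictIfLeaf is collapsed in this diff; based on the split, leftNode, and rightNode fields above, it presumably descends the tree along the split, roughly like the sketch below (the field name threshold and the comparison direction are assumptions, not taken from the hunk):

// Hypothetical reconstruction of the collapsed non-leaf branch for a continuous split:
// go left when the feature value is at or below the threshold, otherwise go right.
if (feature(split.get.feature) <= split.get.threshold) {
  leftNode.get.predictIfLeaf(feature)
} else {
  rightNode.get.predictIfLeaf(feature)
}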
@@ -42,15 +42,15 @@ case class Split(
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
class DummyLowSplit(feature: Int, featureType : FeatureType)
class DummyLowSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MinValue, featureType, List())

/**
* Split with maximum threshold for continuous features. Helps with the highest bin creation.
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
class DummyHighSplit(feature: Int, featureType : FeatureType)
class DummyHighSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())

/**
@@ -59,6 +59,6 @@ class DummyHighSplit(feature: Int, featureType : FeatureType)
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
class DummyCategoricalSplit(feature: Int, featureType : FeatureType)
class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())
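Putting the three dummy classes together, the intent described in their doc comments is roughly the following; Continuous and Categorical are assumed to be the FeatureType values, and the bin layout is an assumption rather than code from this PR:

// The lowest and highest bins of a continuous feature are capped by the dummy splits,
// so any feature value falls into some bin.
val lowCap  = new DummyLowSplit(0, Continuous)    // threshold = Double.MinValue
val highCap = new DummyHighSplit(0, Continuous)   // threshold = Double.MaxValue
// For a categorical feature, a single dummy split stands in for the highest bin edge.
val catCap  = new DummyCategoricalSplit(1, Categorical)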
