File: DecisionTreeClassifier.scala

@@ -18,16 +18,17 @@
package org.apache.spark.ml.classification

import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.functions.{lit, col}

/**
* :: Experimental ::
@@ -39,7 +40,7 @@ import org.apache.spark.sql.DataFrame
@Experimental
final class DecisionTreeClassifier(override val uid: String)
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
-with DecisionTreeParams with TreeClassifierParams {
+with DecisionTreeParams with TreeClassifierParams with HasWeightCol {

def this() = this(Identifiable.randomUID("dtc"))

@@ -62,6 +63,15 @@ final class DecisionTreeClassifier(override val uid: String)

override def setImpurity(value: String): this.type = super.setImpurity(value)

+/**
+ * Whether to over-/under-sample training instances according to the given weights in weightCol.
+ * If empty, all instances are treated equally (weight 1.0).
+ * Default is empty, so all instances have weight one.
+ * @group setParam
+ */
+def setWeightCol(value: String): this.type = set(weightCol, value)
+setDefault(weightCol -> "")
+
override def setSeed(value: Long): this.type = super.setSeed(value)

override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
@@ -74,7 +84,11 @@ final class DecisionTreeClassifier(override val uid: String)
" specified. See StringIndexer.")
// TODO: Automatically index labels: SPARK-7126
}
-val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+val oldDataset = dataset.select(col($(labelCol)), w, col($(featuresCol))).map {
+  case Row(label: Double, weight: Double, features: Vector) =>
+    Instance(label, weight, features)
+}
val strategy = getOldStrategy(categoricalFeatures, numClasses)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), parentUID = Some(uid))
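For context, a minimal sketch of how a caller might exercise the new weightCol param, assuming the 1.x-era spark.ml API this PR targets. The object name, column names, and data are made up; StringIndexer is used because train() above rejects label columns without class-count metadata.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

object WeightColExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("WeightColExample"))
    val sqlContext = new SQLContext(sc)

    // Made-up data: raw label, per-instance weight, feature vector.
    val df = sqlContext.createDataFrame(Seq(
      ("a", 1.0, Vectors.dense(0.0, 1.0)),
      ("a", 1.0, Vectors.dense(0.2, 0.8)),
      ("b", 3.0, Vectors.dense(1.0, 0.0))   // counts roughly three times as much
    )).toDF("rawLabel", "weight", "features")

    // StringIndexer attaches the label metadata that train() checks for.
    val indexer = new StringIndexer().setInputCol("rawLabel").setOutputCol("label")
    val dtc = new DecisionTreeClassifier()
      .setLabelCol("label")
      .setWeightCol("weight")   // the setter introduced by this PR

    val model = new Pipeline().setStages(Array(indexer, dtc)).fit(df)
    println(model.stages.last)

    sc.stop()
  }
}

The same setter and weighted-extraction pattern applies to the other three estimators in this diff.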
File: RandomForestClassifier.scala

@@ -18,17 +18,17 @@
package org.apache.spark.ml.classification

import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors}
-import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.functions.{col, lit, udf}


/**
@@ -41,7 +41,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class RandomForestClassifier(override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
-with RandomForestParams with TreeClassifierParams {
+with RandomForestParams with TreeClassifierParams with HasWeightCol{

[Review comment, Contributor] nit: Space after HasWeightCol

def this() = this(Identifiable.randomUID("rfc"))

@@ -79,6 +79,15 @@ final class RandomForestClassifier(override val uid: String)
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)

+/**
+ * Whether to over-/under-sample training instances according to the given weights in weightCol.
+ * If empty, all instances are treated equally (weight 1.0).
+ * Default is empty, so all instances have weight one.
+ * @group setParam
+ */
+def setWeightCol(value: String): this.type = set(weightCol, value)
+setDefault(weightCol -> "")
+
override protected def train(dataset: DataFrame): RandomForestClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -89,7 +98,12 @@ final class RandomForestClassifier(override val uid: String)
" specified. See StringIndexer.")
// TODO: Automatically index labels: SPARK-7126
}
-val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+
+val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+val oldDataset = dataset.select(col($(labelCol)), w, col($(featuresCol))).map {
+  case Row(label: Double, weight: Double, features: Vector) =>
+    Instance(label, weight, features)
+}
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
val trees =
File: DecisionTreeRegressor.scala

@@ -17,18 +17,20 @@

package org.apache.spark.ml.regression

+

[Review comment, Contributor] nit: remove extra line

import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.functions.{lit, col}

/**
* :: Experimental ::
@@ -40,7 +42,7 @@ import org.apache.spark.sql.DataFrame
@Experimental
final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
-with DecisionTreeParams with TreeRegressorParams {
+with DecisionTreeParams with TreeRegressorParams with HasWeightCol{

[Review comment, Contributor] Space after HasWeightCol

@Since("1.4.0")
def this() = this(Identifiable.randomUID("dtr"))
@@ -73,10 +75,23 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val

override def setSeed(value: Long): this.type = super.setSeed(value)

+/**
+ * Whether to over-/under-sample training instances according to the given weights in weightCol.
+ * If empty, all instances are treated equally (weight 1.0).
+ * Default is empty, so all instances have weight one.
+ * @group setParam
+ */
+def setWeightCol(value: String): this.type = set(weightCol, value)
+setDefault(weightCol -> "")
+
override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+val oldDataset = dataset.select(col($(labelCol)), w, col($(featuresCol))).map {
+  case Row(label: Double, weight: Double, features: Vector) =>
+    Instance(label, weight, features)
+}
val strategy = getOldStrategy(categoricalFeatures)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), parentUID = Some(uid))
File: RandomForestRegressor.scala

@@ -18,18 +18,18 @@
package org.apache.spark.ml.regression

import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.functions.{col, lit, udf}


/**
@@ -41,7 +41,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
-with RandomForestParams with TreeRegressorParams {
+with RandomForestParams with TreeRegressorParams with HasWeightCol{

[Review comment, Contributor] Space after HasWeightCol

@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfr"))
@@ -89,10 +89,23 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)

+/**
+ * Whether to over-/under-sample training instances according to the given weights in weightCol.
+ * If empty, all instances are treated equally (weight 1.0).
+ * Default is empty, so all instances have weight one.
+ * @group setParam
+ */
+def setWeightCol(value: String): this.type = set(weightCol, value)
+setDefault(weightCol -> "")
+
override protected def train(dataset: DataFrame): RandomForestRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+val oldDataset = dataset.select(col($(labelCol)), w, col($(featuresCol))).map {
+  case Row(label: Double, weight: Double, features: Vector) =>
+    Instance(label, weight, features)
+}
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
val trees =
File: RandomForest.scala (org.apache.spark.ml.tree.impl)

@@ -24,6 +24,7 @@ import scala.util.Random

import org.apache.spark.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.mllib.linalg.{Vectors, Vector}
@@ -43,11 +44,11 @@ private[ml] object RandomForest extends Logging {

/**
* Train a random forest.
-* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+* @param input Training data: RDD of [[org.apache.spark.ml.feature.Instance]]
* @return an unweighted set of trees
*/
def run(
-input: RDD[LabeledPoint],
+input: RDD[Instance],
strategy: OldStrategy,
numTrees: Int,
featureSubsetStrategy: String,
@@ -60,9 +61,8 @@ private[ml] object RandomForest extends Logging {

timer.start("init")

-val retaggedInput = input.retag(classOf[LabeledPoint])
-val metadata =
-  DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
+val retaggedInput = input.retag(classOf[Instance])
+val metadata = buildWeightedMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
logDebug("algo = " + strategy.algo)
logDebug("numTrees = " + numTrees)
logDebug("seed = " + seed)
@@ -87,8 +87,10 @@

val withReplacement = numTrees > 1

-val baggedInput = BaggedPoint
+val unadjustedBaggedInput = BaggedPoint
  .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
+
+val baggedInput = reweightSubSampleWeights(unadjustedBaggedInput)
.persist(StorageLevel.MEMORY_AND_DISK)

// depth of the decision tree
@@ -823,7 +825,7 @@
* of size (numFeatures, numBins).
*/
protected[tree] def findSplits(
-input: RDD[LabeledPoint],
+input: RDD[Instance],
metadata: DecisionTreeMetadata,
seed : Long): Array[Array[Split]] = {

@@ -844,7 +846,7 @@
logDebug("fraction of data used for calculating quantiles = " + fraction)
input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
} else {
-new Array[LabeledPoint](0)
+new Array[Instance](0)
}

val splits = new Array[Array[Split]](numFeatures)
@@ -1171,4 +1173,28 @@
}
}

+/**
+ * Inject the sample weight to sub-sample weights of the baggedPoints
+ */
+private[impl] def reweightSubSampleWeights(

[Review comment, Contributor] There is a TODO in BaggedPoint.scala for accepting weighted instances. This might be a good time to address that. If not, we will have to implement this in this JIRA, fix Bagged Point in another JIRA, and then return to this, likely in a third JIRA. Thoughts?

+    baggedTreePoints: RDD[BaggedPoint[TreePoint]]): RDD[BaggedPoint[TreePoint]] = {
+  baggedTreePoints.map {bagged =>
+    val treePoint = bagged.datum
+    val adjustedSubSampleWeights = bagged.subsampleWeights.map(w => w * treePoint.weight)
+    new BaggedPoint[TreePoint](treePoint, adjustedSubSampleWeights)
+  }
+}
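
To make the adjustment concrete, here is a standalone sketch of the same arithmetic, with plain doubles standing in for the private BaggedPoint/TreePoint types (the object and method names here are made up):

object ReweightSketch {
  // Mirror of reweightSubSampleWeights: fold the per-instance sample weight
  // into the per-tree bootstrap counts produced by BaggedPoint.
  def reweight(sampleWeight: Double, subsampleWeights: Array[Double]): Array[Double] =
    subsampleWeights.map(_ * sampleWeight)

  def main(args: Array[String]): Unit = {
    // An instance with sample weight 2.0 that the bootstrap drew 1x, 0x, and 3x
    // for three trees contributes with weights 2.0, 0.0, and 6.0 respectively.
    println(reweight(2.0, Array(1.0, 0.0, 3.0)).mkString(", "))  // prints: 2.0, 0.0, 6.0
  }
}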

+/**
+ * A thin adaptor to [[org.apache.spark.mllib.tree.impl.DecisionTreeMetadata.buildMetadata]]

[Review comment, Contributor] nit: "adaptor" -> "adapter". I'm not sure it is currently incorrect but "adapter" is the significantly more common spelling.

+ */
+private[impl] def buildWeightedMetadata(
+    input: RDD[Instance],
+    strategy: OldStrategy,
+    numTrees: Int,
+    featureSubsetStrategy: String): DecisionTreeMetadata = {
+  val unweightedInput = input.map {w => LabeledPoint(w.label, w.features)}
+  DecisionTreeMetadata.buildMetadata(unweightedInput, strategy, numTrees, featureSubsetStrategy)
+}
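
One consequence worth noting: buildWeightedMetadata drops the weights before delegating to the old buildMetadata, so per-feature metadata such as bin counts is computed as if every instance had weight 1.0; the sample weights enter training only through the reweighted bagged points above.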
}
