-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-9478] [ml] Add class weights to Random Forest #9008
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,18 +17,20 @@ | |
|
|
||
| package org.apache.spark.ml.regression | ||
|
|
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: remove extra line |
||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.ml.feature.Instance | ||
| import org.apache.spark.ml.param.shared.HasWeightCol | ||
| import org.apache.spark.ml.{PredictionModel, Predictor} | ||
| import org.apache.spark.ml.param.ParamMap | ||
| import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} | ||
| import org.apache.spark.ml.tree.impl.RandomForest | ||
| import org.apache.spark.ml.util.{Identifiable, MetadataUtils} | ||
| import org.apache.spark.mllib.linalg.Vector | ||
| import org.apache.spark.mllib.regression.LabeledPoint | ||
| import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} | ||
| import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.DataFrame | ||
| import org.apache.spark.sql.{Row, DataFrame} | ||
| import org.apache.spark.sql.functions.{lit, col} | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
|
|
@@ -40,7 +42,7 @@ import org.apache.spark.sql.DataFrame | |
| @Experimental | ||
| final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) | ||
| extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] | ||
| with DecisionTreeParams with TreeRegressorParams { | ||
| with DecisionTreeParams with TreeRegressorParams with HasWeightCol{ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: Space after `HasWeightCol` |
||
|
|
||
| @Since("1.4.0") | ||
| def this() = this(Identifiable.randomUID("dtr")) | ||
|
|
@@ -73,10 +75,23 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val | |
|
|
||
| override def setSeed(value: Long): this.type = super.setSeed(value) | ||
|
|
||
| /** | ||
| * Whether to over-/under-sample training instances according to the given weights in weightCol. | ||
| * If empty, all instances are treated equally (weight 1.0). | ||
| * Default is empty, so all instances have weight one. | ||
| * @group setParam | ||
| */ | ||
| def setWeightCol(value: String): this.type = set(weightCol, value) | ||
| setDefault(weightCol -> "") | ||
|
|
||
| override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { | ||
| val categoricalFeatures: Map[Int, Int] = | ||
| MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) | ||
| val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) | ||
| val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) | ||
| val oldDataset = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { | ||
| case Row(label: Double, weight: Double, features: Vector) => | ||
| Instance(label, weight, features) | ||
| } | ||
| val strategy = getOldStrategy(categoricalFeatures) | ||
| val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", | ||
| seed = $(seed), parentUID = Some(uid)) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,18 +18,18 @@ | |
| package org.apache.spark.ml.regression | ||
|
|
||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.ml.feature.Instance | ||
| import org.apache.spark.ml.param.shared.HasWeightCol | ||
| import org.apache.spark.ml.{PredictionModel, Predictor} | ||
| import org.apache.spark.ml.param.ParamMap | ||
| import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} | ||
| import org.apache.spark.ml.tree.impl.RandomForest | ||
| import org.apache.spark.ml.util.{Identifiable, MetadataUtils} | ||
| import org.apache.spark.mllib.linalg.Vector | ||
| import org.apache.spark.mllib.regression.LabeledPoint | ||
| import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} | ||
| import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.DataFrame | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.{Row, DataFrame} | ||
| import org.apache.spark.sql.functions.{col, lit, udf} | ||
|
|
||
|
|
||
| /** | ||
|
|
@@ -41,7 +41,7 @@ import org.apache.spark.sql.functions._ | |
| @Experimental | ||
| final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) | ||
| extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] | ||
| with RandomForestParams with TreeRegressorParams { | ||
| with RandomForestParams with TreeRegressorParams with HasWeightCol{ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: Space after `HasWeightCol` |
||
|
|
||
| @Since("1.4.0") | ||
| def this() = this(Identifiable.randomUID("rfr")) | ||
|
|
@@ -89,10 +89,23 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val | |
| override def setFeatureSubsetStrategy(value: String): this.type = | ||
| super.setFeatureSubsetStrategy(value) | ||
|
|
||
| /** | ||
| * Whether to over-/under-sample training instances according to the given weights in weightCol. | ||
| * If empty, all instances are treated equally (weight 1.0). | ||
| * Default is empty, so all instances have weight one. | ||
| * @group setParam | ||
| */ | ||
| def setWeightCol(value: String): this.type = set(weightCol, value) | ||
| setDefault(weightCol -> "") | ||
|
|
||
| override protected def train(dataset: DataFrame): RandomForestRegressionModel = { | ||
| val categoricalFeatures: Map[Int, Int] = | ||
| MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) | ||
| val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) | ||
| val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) | ||
| val oldDataset = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { | ||
| case Row(label: Double, weight: Double, features: Vector) => | ||
| Instance(label, weight, features) | ||
| } | ||
| val strategy = | ||
| super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) | ||
| val trees = | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ import scala.util.Random | |
|
|
||
| import org.apache.spark.Logging | ||
| import org.apache.spark.ml.classification.DecisionTreeClassificationModel | ||
| import org.apache.spark.ml.feature.Instance | ||
| import org.apache.spark.ml.regression.DecisionTreeRegressionModel | ||
| import org.apache.spark.ml.tree._ | ||
| import org.apache.spark.mllib.linalg.{Vectors, Vector} | ||
|
|
@@ -43,11 +44,11 @@ private[ml] object RandomForest extends Logging { | |
|
|
||
| /** | ||
| * Train a random forest. | ||
| * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] | ||
| * @param input Training data: RDD of [[org.apache.spark.ml.feature.Instance]] | ||
| * @return an unweighted set of trees | ||
| */ | ||
| def run( | ||
| input: RDD[LabeledPoint], | ||
| input: RDD[Instance], | ||
| strategy: OldStrategy, | ||
| numTrees: Int, | ||
| featureSubsetStrategy: String, | ||
|
|
@@ -60,9 +61,8 @@ private[ml] object RandomForest extends Logging { | |
|
|
||
| timer.start("init") | ||
|
|
||
| val retaggedInput = input.retag(classOf[LabeledPoint]) | ||
| val metadata = | ||
| DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) | ||
| val retaggedInput = input.retag(classOf[Instance]) | ||
| val metadata = buildWeightedMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) | ||
| logDebug("algo = " + strategy.algo) | ||
| logDebug("numTrees = " + numTrees) | ||
| logDebug("seed = " + seed) | ||
|
|
@@ -87,8 +87,10 @@ private[ml] object RandomForest extends Logging { | |
|
|
||
| val withReplacement = numTrees > 1 | ||
|
|
||
| val baggedInput = BaggedPoint | ||
| val unadjustedBaggedInput = BaggedPoint | ||
| .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed) | ||
|
|
||
| val baggedInput = reweightSubSampleWeights(unadjustedBaggedInput) | ||
| .persist(StorageLevel.MEMORY_AND_DISK) | ||
|
|
||
| // depth of the decision tree | ||
|
|
@@ -823,7 +825,7 @@ private[ml] object RandomForest extends Logging { | |
| * of size (numFeatures, numBins). | ||
| */ | ||
| protected[tree] def findSplits( | ||
| input: RDD[LabeledPoint], | ||
| input: RDD[Instance], | ||
| metadata: DecisionTreeMetadata, | ||
| seed : Long): Array[Array[Split]] = { | ||
|
|
||
|
|
@@ -844,7 +846,7 @@ private[ml] object RandomForest extends Logging { | |
| logDebug("fraction of data used for calculating quantiles = " + fraction) | ||
| input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() | ||
| } else { | ||
| new Array[LabeledPoint](0) | ||
| new Array[Instance](0) | ||
| } | ||
|
|
||
| val splits = new Array[Array[Split]](numFeatures) | ||
|
|
@@ -1171,4 +1173,28 @@ private[ml] object RandomForest extends Logging { | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Inject the sample weight to sub-sample weights of the baggedPoints | ||
| */ | ||
| private[impl] def reweightSubSampleWeights( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is a TODO in BaggedPoint.scala for accepting weighted instances. This might be a good time to address that. If not, we will have to implement this in this JIRA, fix BaggedPoint in another JIRA, and then return to this, likely in a third JIRA. Thoughts? |
||
| baggedTreePoints: RDD[BaggedPoint[TreePoint]]): RDD[BaggedPoint[TreePoint]] = { | ||
| baggedTreePoints.map {bagged => | ||
| val treePoint = bagged.datum | ||
| val adjustedSubSampleWeights = bagged.subsampleWeights.map(w => w * treePoint.weight) | ||
| new BaggedPoint[TreePoint](treePoint, adjustedSubSampleWeights) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * A thin adaptor to [[org.apache.spark.mllib.tree.impl.DecisionTreeMetadata.buildMetadata]] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: "adaptor" -> "adapter". I'm not sure it is currently incorrect but "adapter" is the significantly more common spelling. |
||
| */ | ||
| private[impl] def buildWeightedMetadata( | ||
| input: RDD[Instance], | ||
| strategy: OldStrategy, | ||
| numTrees: Int, | ||
| featureSubsetStrategy: String): DecisionTreeMetadata = { | ||
| val unweightedInput = input.map {w => LabeledPoint(w.label, w.features)} | ||
| DecisionTreeMetadata.buildMetadata(unweightedInput, strategy, numTrees, featureSubsetStrategy) | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
nit: Space after `HasWeightCol`.