Conversation

@sethah
Contributor

@sethah sethah commented Jan 27, 2017

What changes were proposed in this pull request?

This patch adds support for sample weights to DecisionTreeRegressor and DecisionTreeClassifier.

Note: This patch does not add support for sample weights to RandomForest. As discussed in the JIRA, we would like to add sample weights into the bagging process. This patch is large enough as is, and there are some additional considerations to be made for random forests. Since the machinery introduced here needs to be present regardless, I have opted to leave random forests for a follow-up PR.

How was this patch tested?

The algorithms are tested to ensure that:

  1. Arbitrary scaling of constant weights has no effect
  2. Outliers with small weights do not affect the learned model
  3. Oversampling and weighting are equivalent

Unit tests are also added to test other smaller components.
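For illustration, here is a minimal sketch of what the three invariance checks amount to. The helper and column names are assumptions for this example, not the actual test code, and it assumes the weightCol support this patch adds:

``` Scala
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit

// Fit a tree using a "weight" column and return its structure as a string.
def fitTree(df: DataFrame): String =
  new DecisionTreeClassifier().setWeightCol("weight").fit(df).toDebugString

def checkInvariants(data: DataFrame, outliers: DataFrame): Unit = {
  // 1. Arbitrary scaling of constant weights has no effect:
  //    all-1.0 weights and all-42.0 weights must learn the same tree.
  assert(fitTree(data.withColumn("weight", lit(1.0))) ==
         fitTree(data.withColumn("weight", lit(42.0))))

  // 2. Outliers with small weights do not affect the learned model:
  //    adding near-zero-weight outliers should leave the tree unchanged.
  val withOutliers = data.withColumn("weight", lit(1.0))
    .union(outliers.withColumn("weight", lit(1e-9)))
  assert(fitTree(withOutliers) == fitTree(data.withColumn("weight", lit(1.0))))

  // 3. Oversampling and weighting are equivalent:
  //    duplicating every row should match giving every row weight 2.0.
  val duplicated = data.union(data).withColumn("weight", lit(1.0))
  assert(fitTree(duplicated) == fitTree(data.withColumn("weight", lit(2.0))))
}
```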

Summary of changes

  • Impurity aggregators now store weighted sufficient statistics. They also store the raw (unweighted) count, since that is needed to enforce minInstancesPerNode.

  • This patch maintains the meaning of minInstancesPerNode, in that the parameter still corresponds to raw, unweighted counts. It also adds a new parameter, minWeightFractionPerNode, which requires that nodes contain at least minWeightFractionPerNode * weightedNumExamples total weight.

  • This patch modifies findSplitsForContinuousFeatures to use weighted sums. Unit tests are added.

  • TreePoint is modified to hold a sample weight.

  • BaggedPoint is modified from:

``` Scala
private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) extends Serializable
```

to

``` Scala
private[spark] class BaggedPoint[Datum](
    val datum: Datum,
    val subsampleCounts: Array[Int],
    val sampleWeight: Double) extends Serializable
```

We do not simply multiply the counts by the weight and store that, because we need both the raw counts and the weight in order to enforce minInstancesPerNode and minWeightFractionPerNode (see the sketch below).

Note: many of the changed files are changed simply due to using Instance instead of LabeledPoint.
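To make the interplay of the two parameters concrete, here is a hedged sketch of the admissibility check they imply; the names here (NodeStats, isValidNode) are illustrative, not the patch's actual internals:

``` Scala
// Illustrative only: shows why both the raw count and the weight are needed.
case class NodeStats(rawCount: Long, weightedCount: Double)

def isValidNode(
    stats: NodeStats,
    minInstancesPerNode: Int,
    minWeightFractionPerNode: Double,
    weightedNumExamples: Double): Boolean = {
  // minInstancesPerNode keeps its old meaning: raw, unweighted counts...
  stats.rawCount >= minInstancesPerNode &&
    // ...while the new parameter bounds the total weight in the node.
    stats.weightedCount >= minWeightFractionPerNode * weightedNumExamples
}
```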

@sethah
Contributor Author

sethah commented Jan 27, 2017

ping @jkbradley @imatiach-msft

@SparkQA

SparkQA commented Jan 28, 2017

Test build #72091 has finished for PR 16722 at commit 7dc1437.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Jan 28, 2017

Test build #72093 has finished for PR 16722 at commit 2729a63.

  • This patch fails MiMa tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Jan 28, 2017

Test build #72098 has finished for PR 16722 at commit 8278724.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Jan 28, 2017

Test build #72099 has finished for PR 16722 at commit 2112720.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@imatiach-msft
Contributor

@sethah nice changes! It looks like a test is failing; I'm not sure if it's related to the PR. Can you take a look?

@sethah
Contributor Author

sethah commented Jan 30, 2017

Yes, the test is failing due to changes here: numerical sensitivity in some of the weight tests. I will fix this soon, but the code is still ready for review in the meantime.

 * Private helper function for comparing two values using absolute tolerance.
 */
-private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
Contributor

I don't see this used in the new code; maybe my search is not working properly in the browser.

Contributor Author

It's not used. I just changed the scope of both methods; I can change it back, of course. I don't see a great reason to make this public, since most users will use relTol instead. I'm open to other opinions though.

Contributor

ok, I don't have a very strong opinion here either

 * the relative tolerance is meaningless, so the exception will be raised to warn users.
 */
-private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+private[ml] def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
Contributor

Might it be better to just make this public if we are using it in tests, similar to other test methods?
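For context, illustration-only versions of the two helpers under discussion; this sketches the logic (the real methods live in Spark's test utilities, and this is not their code):

``` Scala
// Absolute tolerance: |x - y| < eps.
def absoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean =
  math.abs(x - y) < eps

// Relative tolerance: |x - y| < eps * min(|x|, |y|). Near zero a relative
// tolerance is meaningless, which is why the real helper raises an exception.
def relativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
  require(math.abs(x) > Double.MinPositiveValue &&
    math.abs(y) > Double.MinPositiveValue,
    "relative tolerance is meaningless when either value is (near) zero")
  math.abs(x - y) < eps * math.min(math.abs(x), math.abs(y))
}

// absoluteErrorComparison(1.0000001, 1.0, 1e-5)  // true
// relativeErrorComparison(100.01, 100.0, 1e-3)   // true
```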

@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
with DecisionTreeClassifierParams with DefaultParamsWritable {
with DecisionTreeClassifierParams with HasWeightCol with DefaultParamsWritable {
Contributor

Was there a specific reason not to put this on DecisionTreeClassifierParams?
It looks like the other classifiers that have this are:
LinearSVC
LogisticRegression
NaiveBayes
and the regressors:
GeneralizedLinearRegression
IsotonicRegression
LinearRegression
and all of them have it on the params trait, not on the class.
However, I do agree with you that it really makes no sense for the model to have this settable, although it may be useful for users to get the information from the model.

s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}

val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
Contributor

It looks like by removing this method call you are removing some valuable validation logic (that exists in the base class). Specifically, this is the logic:

require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
  s" $numClasses, but requires numClasses > 0.")

require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
  s" dataset with invalid label $label. Labels must be integers in range" +
  s" [0, $numClasses).")

Contributor Author

Good catch. Actually, this problem exists elsewhere (e.g. LogisticRegression). What do you think about adding it back manually here and then addressing the larger issue in a separate JIRA?

Contributor

I would say that's fine if it were only in one place, but I also see this pattern in DecisionTreeRegressor.scala; it seems like we should be able to refactor this part out.

Contributor Author

For regressors, extractLabeledPoints doesn't do any extra checking. The larger issue is that we are manually "extracting instances" while the convenience methods we have are for labeled points. Correcting it now, in this PR, likely means implementing the framework to correct it everywhere, which is a larger and orthogonal change. So I think we could just add the check manually to the classifier (see the sketch below), then create a JIRA that addresses consolidating these, probably by adding extractInstances methods analogous to their labeled point counterparts. This PR is large enough as is, without having to think about adding that method and then implementing it in all the other algos that manually extract instances, IMO.
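A hedged sketch of what re-adding the check manually might look like during weighted-instance extraction; the surrounding context (dataset, w, numClasses) is assumed, and the require text is quoted from the base-class check above:

``` Scala
// Illustrative: validate labels inline while building the RDD[Instance].
val instances = dataset
  .select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
    case Row(label: Double, weight: Double, features: Vector) =>
      // Same validation that extractLabeledPoints performed in the base class:
      require(label % 1 == 0 && label >= 0 && label < numClasses,
        s"Classifier was given dataset with invalid label $label. " +
        s"Labels must be integers in range [0, $numClasses).")
      Instance(label, weight, features)
  }
```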

Contributor

@imatiach-msft imatiach-msft Feb 15, 2017

sounds reasonable, thanks for the explanation.

s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}

val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
Contributor

Same as above: it looks like some validation logic is missing here.


-val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+val instances = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+  case Row(label: Double, features: Vector) => Instance(label, 1.0, features)
Contributor

It looks like we aren't getting the weight column here; I'm not sure why this file needed to be changed.

Contributor Author

We have to pass an RDD[Instance] to RandomForest.run. I changed this back to use extractLabeledPoints.

@imatiach-msft
Contributor

Sorry, I haven't finished reviewing yet; I will look more tomorrow.

@sethah
Contributor Author

sethah commented Jan 31, 2017

Thanks for taking a look @imatiach-msft, much appreciated!

@SparkQA

SparkQA commented Jan 31, 2017

Test build #72184 has finished for PR 16722 at commit 159c5a6.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Feb 1, 2017

Test build #72220 has finished for PR 16722 at commit 1db8494.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@sethah
Contributor Author

sethah commented Feb 1, 2017

jenkins retest this please

@SparkQA

SparkQA commented Feb 1, 2017

Test build #72225 has finished for PR 16722 at commit 1db8494.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

s"($label,$features)"
}

private[spark] def toInstance: Instance = toInstance(1.0)
Contributor

This is kind of a nitpick, and optional, but I would usually refactor out magic numbers like 1.0 into something like "defaultWeight" and reuse it elsewhere. It's not really necessary in this case, though, since it probably won't ever change.

Contributor Author

Actually, I'd prefer to remove the no-arg function and be explicit everywhere. That way there is no ambiguity and no unintended effects if someone changes the default value. Sound OK?
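For reference, the explicit variant that remains is essentially a one-liner on LabeledPoint; this is a sketch in the style of the diff hunks above, not a quoted line (and Instance is private[ml] in Spark, so it is illustration only):

``` Scala
// The caller must always state the weight explicitly.
private[spark] def toInstance(weight: Double): Instance =
  Instance(label, weight, features)

// Call sites then read e.g. point.toInstance(1.0), so changing any default
// weight can never silently alter behavior at a distance.
```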

}

-val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance(1.0))
Contributor

@imatiach-msft imatiach-msft Feb 2, 2017

Minor simplification:
it looks like this:
toInstance(1.0)
can just be simplified to:
toInstance

Contributor

Update: since you removed the overload, this comment is no longer valid.

dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
  case Row(label: Double, weight: Double, features: Vector) =>
    Instance(label, weight, features)
}
Contributor

The code above looks the same as the classifier; can we refactor somehow?

val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances =
  dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
    case Row(label: Double, weight: Double, features: Vector) =>
      Instance(label, weight, features)
  }

Contributor

@imatiach-msft imatiach-msft Feb 15, 2017

Update: it sounds like you are going to create a separate JIRA for refactoring this code; that is reasonable to me.
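A sketch of the shared helper such a refactoring JIRA might introduce, mirroring the existing extractLabeledPoints; the name extractInstances and its placement are hypothetical here:

``` Scala
// Hypothetical shared extraction used by both the classifier and regressor:
protected def extractInstances(dataset: Dataset[_]): RDD[Instance] = {
  val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
  dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
    case Row(label: Double, weight: Double, features: Vector) =>
      Instance(label, weight, features)
  }
}
```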

 MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
-val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+val instances = extractLabeledPoints(dataset).map(_.toInstance(1.0))
Contributor

simplify to toInstance (without the 1.0)

@SparkQA

SparkQA commented Feb 6, 2017

Test build #72471 has finished for PR 16722 at commit 48b1258.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@sethah sethah changed the title from [SPARK-9478][ML][MLlib] Add sample weights to decision trees to [SPARK-19591][ML][MLlib] Add sample weights to decision trees on Feb 14, 2017
withReplacement: Boolean,
extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0,
seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
// TODO: implement weighted bootstrapping
Contributor

@imatiach-msft imatiach-msft Feb 15, 2017

Not really related to this review: is there a JIRA for this TODO, and how would it be done? Also, consider referencing the JIRA in the code.

Contributor Author

There was some discussion on the JIRA about it. Actually, we may or may not do this, so I'll remove it in the next commit.

Contributor

sure, sounds good
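For the record, one speculative way weighted bootstrapping could be done; this is exactly the open design question from the JIRA, so treat it purely as a sketch: scale each point's Poisson resampling mean by its normalized weight.

``` Scala
import org.apache.commons.math3.distribution.PoissonDistribution

// Speculative sketch: replicate count for one point under a weighted bootstrap.
// Assumes weight > 0 and meanWeight = (sum of weights) / (number of points).
def weightedReplicates(
    weight: Double,
    meanWeight: Double,
    subsamplingRate: Double,
    seed: Long): Int = {
  val poisson = new PoissonDistribution(subsamplingRate * weight / meanWeight)
  poisson.reseedRandomGenerator(seed)
  poisson.sample()
}
```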

require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
s"but was given an empty features vector")
val numExamples = input.count()
val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
Contributor

Although in principle none of these changes should impact performance much, have you had a chance to verify that the execution time is the same as before?

Contributor Author

No, I haven't. I think it's very low risk, as you say.
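For completeness, a sketch of how the truncated aggregate above plausibly continues, computing the raw count and total weight in a single pass; the exact lambda style is assumed, not quoted from the patch:

``` Scala
val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
  seqOp = { case ((count, sum), instance) => (count + 1L, sum + instance.weight) },
  combOp = { case ((c1, s1), (c2, s2)) => (c1 + c2, s1 + s2) }
)
```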

@imatiach-msft
Contributor

imatiach-msft commented Feb 15, 2017

The code looks good to me; maybe a committer can comment? This is a great feature, nice work!

 subsampleIndex += 1
 }
-new BaggedPoint(instance, subsampleWeights)
+new BaggedPoint(instance, subsampleCounts, 1.0)
Contributor

hmm... shouldn't the sample weight be passed instead of 1.0 in this case?

Contributor Author

Sample weights for sampling with/without replacement are effectively not implemented yet. We'll need to think about it for RandomForest though.

@imatiach-msft
Contributor

@jkbradley might you be able to take a look at the changes from @sethah? Thank you!

@jkbradley
Member

Hi all, I can try to track this work now.

> This patch maintains the meaning of minInstancesPerNode, in that the parameter still corresponds to raw, unweighted counts. It also adds a new parameter minWeightFractionPerNode which requires that nodes must contain at least minWeightFractionPerNode * weightedNumExamples total weight.

I think minInstancesPerNode should use sample weights, not unweighted counts. IMHO that matches the semantics of minInstancesPerNode and sample weights better: a sample with weight 2 is worth 2 samples with weight 1, so minInstancesPerNode should treat those cases in the same way.

If we do that, then the old BaggedPoint should work.
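A toy illustration of the semantic point (self-contained, not Spark code): under weighted counting, one sample of weight 2 and two samples of weight 1 pass or fail a minInstancesPerNode-style check together.

``` Scala
case class Pt(weight: Double)

def weightedCount(node: Seq[Pt]): Double = node.map(_.weight).sum

val a = Seq(Pt(2.0))            // one point with weight 2
val b = Seq(Pt(1.0), Pt(1.0))   // two points with weight 1 each

assert(weightedCount(a) == weightedCount(b))  // identical under weighted semantics
assert(a.size != b.size)                      // but different under raw counts
```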

@jkbradley
Member

@sethah Would you have time to fix the conflicts so I can do a final review? It'll be great to get this into 2.2 if we can. If you are too busy, I'd be happy to take it over (though you'd be the primary PR author in the commit). Thanks!

@sethah
Contributor Author

sethah commented Mar 30, 2017

I don't think I'll have enough time before 2.2. Please feel free to take it over. I will try to help with review. Otherwise, I could pick it back up if it doesn't make 2.2.

@jkbradley
Member

OK thanks! I'll send an update soon.

ghost pushed a commit to dbtsai/spark that referenced this pull request Apr 5, 2017
…rsWithSmallWeights

## What changes were proposed in this pull request?

This is a small piece from apache#16722 which ultimately will add sample weights to decision trees.  This is to allow more flexibility in testing outliers since linear models and trees behave differently.

Note: The primary author when this is committed should be sethah since this is taken from his code.

## How was this patch tested?

Existing tests

Author: Joseph K. Bradley <[email protected]>

Closes apache#17501 from jkbradley/SPARK-20183.
@jkbradley
Member

Btw, I've been working on this and just posted some thoughts about one design choice here: https://issues.apache.org/jira/browse/SPARK-9478

-val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
-  case ((m, cnt), x) =>
-    (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
+val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Double], 0.0)) {

Hi, thanks for your contribution! I have a question about considering weight info in findSplitsForContinuousFeature here. It looks like continuous features will be influenced much more by instance weights, because the weights are considered twice: (1) when making splits and (2) when calculating impurity. Normally, weights are only used in the impurity calculation, according to the limited papers I have read. Could you provide some reference for the approach you use here? And correct me if I misunderstand your code. :) Thanks!
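To pin down what the diff above computes, here is a self-contained sketch of the weighted fold, assuming featureSamples now carries (weight, value) pairs as the new signature suggests: each candidate value accumulates total weight rather than a raw occurrence count, and that weighted mass then drives where the continuous-split thresholds are placed.

``` Scala
// (weight, featureValue) pairs; the values and weights here are made up.
val featureSamples: Seq[(Double, Double)] =
  Seq((1.0, 0.5), (3.0, 0.5), (1.0, 2.0))

val (valueCountMap, numSamples) =
  featureSamples.foldLeft((Map.empty[Double, Double], 0.0)) {
    case ((m, cnt), (w, x)) =>
      (m + ((x, m.getOrElse(x, 0.0) + w)), cnt + w)
  }
// valueCountMap == Map(0.5 -> 4.0, 2.0 -> 1.0); numSamples == 5.0
```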

@zengxy

zengxy commented Nov 1, 2017

@sethah Hi, thanks for your work on this. Do you have any plans to get it into 2.3?

@holdenk
Contributor

holdenk commented Jun 13, 2018

Hi @sethah, is this something you're still working on, or would it be OK for someone else to take over the PR, with you perhaps helping on the review side?

@holdenk
Contributor

holdenk commented Jun 22, 2018

ping @sethah

@imatiach-msft
Contributor

@holdenk @sethah I'd be happy to help out in my spare time and take over the PR. It looks like it just needs to be updated to the latest code; are there any other changes required? Thanks!

@sethah
Contributor Author

sethah commented Jun 22, 2018

Yes, feel free to take this over.

@holdenk
Contributor

holdenk commented Jun 28, 2018

Super excited to have you take this over, @imatiach-msft. Let me know if I can be of help.

@HyukjinKwon
Member

Is this being taken over by #21632?

@imatiach-msft
Contributor

@HyukjinKwon yes, I've updated this PR in #21632

@srowen
Member

srowen commented Jan 7, 2019

Closing in favor of #21632

@srowen srowen closed this Jan 7, 2019
srowen pushed a commit that referenced this pull request Jan 25, 2019
This is updated PR #16722 to latest master

Closes #21632 from imatiach-msft/ilmat/sample-weights.

Authored-by: Ilya Matiach <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
jackylee-ch pushed a commit to jackylee-ch/spark that referenced this pull request Feb 18, 2019
This is updated PR apache#16722 to latest master

Closes apache#21632 from imatiach-msft/ilmat/sample-weights.

Authored-by: Ilya Matiach <[email protected]>
Signed-off-by: Sean Owen <[email protected]>