[SPARK-19591][ML][MLlib] Add sample weights to decision trees #16722
Conversation
ping @jkbradley @imatiach-msft

Test build #72091 has finished for PR 16722 at commit

Test build #72093 has finished for PR 16722 at commit

Test build #72098 has finished for PR 16722 at commit

Force-pushed from 8278724 to 2112720

Test build #72099 has finished for PR 16722 at commit
@sethah nice changes! It looks like a test is failing; not sure if it's related to the PR, can you take a look?

Yes, the test is failing due to changes here: numerical sensitivity in some of the weight testing. I will fix this soon, but the code is still ready for review in the meantime.
  * Private helper function for comparing two values using absolute tolerance.
  */
- private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+ private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
I don't see this used in the new code; maybe my search isn't working properly in the browser.
It's not used. I just changed the scope of both methods, I can change it back of course. I don't see a great reason to make this public since most users will use relTol instead. I'm open to other opinions though.
ok, I don't have a very strong opinion here either
  * the relative tolerance is meaningless, so the exception will be raised to warn users.
  */
- private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+ private[ml] def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
might it be better to just make this public, if we are using it in tests, similar to other test methods?
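For context, the two helpers under discussion compare doubles using absolute versus relative tolerance. A minimal, self-contained sketch of what such comparisons typically look like (the bodies below are assumptions for illustration, not the actual Spark TestingUtils code):

```scala
object ToleranceSketch {
  // Absolute tolerance: the raw difference must be below eps.
  def absoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean =
    math.abs(x - y) < eps

  // Relative tolerance: the difference is measured against the magnitudes of x and y.
  // Near zero this is meaningless, which is why the real helper warns users instead.
  def relativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
    require(x != 0.0 && y != 0.0, "relative tolerance is meaningless when either value is 0")
    math.abs(x - y) < eps * math.max(math.abs(x), math.abs(y))
  }
}
```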
  @Since("1.4.0") override val uid: String)
  extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
- with DecisionTreeClassifierParams with DefaultParamsWritable {
+ with DecisionTreeClassifierParams with HasWeightCol with DefaultParamsWritable {
Was there a specific reason not to put this on DecisionTreeClassifierParams?
It looks like the other classifiers that have this are:
LinearSVC
LogisticRegression
NaiveBayes
and the regressors:
GeneralizedLinearRegression
IsotonicRegression
LinearRegression
and all have it on the params, not on the class.
However, I do agree with you that it really makes no sense for the model to have this settable, although it may be useful for users to get the information from the model.
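To make the trade-off concrete, here is a toy sketch in plain Scala (these stand-in traits are hypothetical; the real types are spark.ml's HasWeightCol, DecisionTreeClassifierParams, DecisionTreeClassifier, and DecisionTreeClassificationModel):

```scala
// Simplified stand-ins for illustration only.
trait HasWeightColLike { var weightCol: String = "" }

// Placement A (the pattern used by LogisticRegression, LinearRegression, etc.):
// the shared params trait mixes it in, so both the estimator and the model expose it.
trait ClassifierParamsA extends HasWeightColLike
class EstimatorA extends ClassifierParamsA
class ModelA extends ClassifierParamsA

// Placement B (this PR): only the estimator mixes it in,
// so the model does not carry a settable weight column.
trait ClassifierParamsB
class EstimatorB extends ClassifierParamsB with HasWeightColLike
class ModelB extends ClassifierParamsB
```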
  s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
  }

- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
It looks like by removing this method call you are removing some valuable validation logic (that exists in the base class). Specifically, this is the logic:

    require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
      s" $numClasses, but requires numClasses > 0.")
    require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
      s" dataset with invalid label $label. Labels must be integers in range" +
      s" [0, $numClasses).")
Good catch. Actually, this problem exists elsewhere (e.g. LogisticRegression). What do you think about adding it back manually here and then addressing the larger issue in a separate JIRA?
I would say that's fine if it were only in one place, but I also see this pattern in DecisionTreeRegressor.scala; it seems like we should be able to refactor this part out.
For regressors, extractLabeledPoints doesn't do any extra checking. The larger issue is that we are manually "extracting instances" even though we have convenience methods for labeled points. Since correcting it now, in this PR, likely means implementing the framework to correct it everywhere - which is a larger and orthogonal change - I think we could just add the check manually to the classifier, then create a JIRA that addresses consolidating these, probably by adding extractInstances methods analogous to their labeled point counterparts. This PR is large enough as is, without having to think about adding that method and then implementing it in all the other algos that manually extract instances, IMO.
sounds reasonable, thanks for the explanation.
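As a rough illustration of the consolidation proposed for the follow-up JIRA, an extractInstances-style helper could re-add the label validation from extractLabeledPoints while also reading the weight column. The name, signature, and body below are assumptions, not the committed code (note that Instance is private[ml], so this only makes sense inside spark.ml):

```scala
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

// Hypothetical helper: extract weighted instances and validate labels the way
// Classifier.extractLabeledPoints does.
def extractInstances(
    dataset: DataFrame,
    labelCol: String,
    featuresCol: String,
    weightCol: Option[String],
    numClasses: Int): RDD[Instance] = {
  require(numClasses > 0,
    s"Classifier found numClasses = $numClasses, but requires numClasses > 0.")
  val w = weightCol.map(col).getOrElse(lit(1.0))
  dataset.select(col(labelCol).cast(DoubleType), w, col(featuresCol)).rdd.map {
    case Row(label: Double, weight: Double, features: Vector) =>
      require(label % 1 == 0 && label >= 0 && label < numClasses,
        s"Classifier was given dataset with invalid label $label. Labels must be " +
        s"integers in range [0, $numClasses).")
      Instance(label, weight, features)
  }
}
```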
  s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
  }

- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
same as above, it looks like some validation logic is missing here
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+ val instances = dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+   case Row(label: Double, features: Vector) => Instance(label, 1.0, features)
it looks like we aren't getting the weight column here; not sure why this file needed to be changed
We have to pass in RDD[Instance] to RandomForest.run. I changed this back to use extractLabeledPoints
sorry, haven't finished reviewing yet, will look more tomorrow

Thanks for taking a look @imatiach-msft, much appreciated!

Test build #72184 has finished for PR 16722 at commit

Test build #72220 has finished for PR 16722 at commit

jenkins retest this please

Test build #72225 has finished for PR 16722 at commit
  s"($label,$features)"
  }

+ private[spark] def toInstance: Instance = toInstance(1.0)
This is kind of a nitpick, and optional, but I would usually refactor out magic numbers like 1.0 as something like "defaultWeight" and reuse it elsewhere. It's not really necessary in this case, though, since it probably won't ever change.
Actually, I'd prefer to remove the no arg function and be explicit everywhere. That way there is no ambiguity or unintended effects if someone changes the default value. Sound ok?
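For reference, the helper under discussion is essentially a one-liner on LabeledPoint. A toy, self-contained sketch of the single-argument form the author prefers to keep (the stand-in case classes are for illustration; the real LabeledPoint and Instance live in spark.ml/mllib):

```scala
// Toy stand-ins for illustration only.
case class Instance(label: Double, weight: Double, features: Array[Double])

case class LabeledPoint(label: Double, features: Array[Double]) {
  // Callers pass the weight explicitly (e.g. 1.0) rather than relying on a
  // no-arg overload with an implicit default of 1.0.
  def toInstance(weight: Double): Instance = Instance(label, weight, features)
}

// Usage: LabeledPoint(1.0, Array(0.5, 2.0)).toInstance(1.0)
```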
  }

- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+ val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance(1.0))
Minor simplification: it looks like toInstance(1.0) can just be simplified to toInstance.
update: since you removed the overload now this comment is no longer valid.
  dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
    case Row(label: Double, weight: Double, features: Vector) =>
      Instance(label, weight, features)
  }
The code above looks the same as the classifier; can we refactor somehow:

    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
    val instances =
      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
        case Row(label: Double, weight: Double, features: Vector) =>
          Instance(label, weight, features)
update: it sounds like you are going to create a separate JIRA for refactoring this code, that is reasonable to me.
  MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+ val instances = extractLabeledPoints(dataset).map(_.toInstance(1.0))
simplify to toInstance (without the 1.0)
Test build #72471 has finished for PR 16722 at commit
  withReplacement: Boolean,
  extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0,
  seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
  // TODO: implement weighted bootstrapping
not really related to this review - is there a JIRA for this TODO - and how would it be done? Also, consider referencing the JIRA in the code.
There was some discussion on the JIRA about it. Actually, we may or may not do this, so I'll remove it in the next commit.
sure, sounds good
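Purely as an illustration of what weighted bootstrapping could mean (not part of this PR, and per the discussion above it may never be implemented): one option would be to scale the Poisson mean used for sampling with replacement by each point's normalized weight. The helper below is hypothetical:

```scala
import org.apache.commons.math3.distribution.PoissonDistribution

// Hypothetical sketch: draw a bootstrap count whose expected value is proportional
// to the sample's weight relative to the mean weight of the dataset.
def weightedBootstrapCount(
    subsamplingRate: Double,
    weight: Double,
    meanWeight: Double,
    seed: Long): Int = {
  val poisson = new PoissonDistribution(subsamplingRate * weight / meanWeight)
  poisson.reseedRandomGenerator(seed)
  poisson.sample()
}
```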
  require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
    s"but was given an empty features vector")
- val numExamples = input.count()
+ val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
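A hedged sketch of the single-pass aggregation that the new line above starts; only the first line appears in the diff, so the seqOp/combOp bodies here are assumptions (input is assumed to be an RDD of weighted instances):

```scala
// Sketch: compute the raw example count and the total sample weight in one pass.
val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
  seqOp = { case ((count, sum), instance) => (count + 1L, sum + instance.weight) },
  combOp = { case ((count1, sum1), (count2, sum2)) => (count1 + count2, sum1 + sum2) })
```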
Although in principle none of these changes should impact performance much, have you had a chance to verify that the execution time is the same as before?
No, I haven't. I think it's very low risk, as you say.
The code looks good to me; maybe a committer can comment? This is a great feature, nice work!
  subsampleIndex += 1
  }
- new BaggedPoint(instance, subsampleWeights)
+ new BaggedPoint(instance, subsampleCounts, 1.0)
hmm... shouldn't the sample weight be passed instead of 1.0 in this case?
Sample weights for sampling with/without replacement are effectively not implemented yet. We'll need to think about it for RandomForest though.
@jkbradley might you be able to take a look at the changes from @sethah? Thank you!

Hi all, I can try to track this work now.
I think minInstancesPerNode should use sample weights, not unweighted counts. IMHO that matches the semantics of minInstancesPerNode and sample weights better: a sample with weight 2 is worth 2 samples with weight 1, so minInstancesPerNode should treat those cases in the same way. If we do that, then the old BaggedPoint should work.

@sethah Would you have time to fix the conflicts so I can do a final review? It'll be great to get this into 2.2 if we can. If you are too busy, I'd be happy to take it over (though you'd be the primary PR author in the commit). Thanks!

I don't think I'll have enough time before 2.2. Please feel free to take it over. I will try to help with review. Otherwise I could pick it back up if it doesn't make 2.2.

OK thanks! I'll send an update soon.
…rsWithSmallWeights

## What changes were proposed in this pull request?

This is a small piece from apache#16722 which ultimately will add sample weights to decision trees. This is to allow more flexibility in testing outliers since linear models and trees behave differently. Note: The primary author when this is committed should be sethah since this is taken from his code.

## How was this patch tested?

Existing tests

Author: Joseph K. Bradley <[email protected]>

Closes apache#17501 from jkbradley/SPARK-20183.
Btw, I've been working on this and just posted some thoughts about one design choice here: https://issues.apache.org/jira/browse/SPARK-9478
- val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
-   case ((m, cnt), x) =>
-     (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
+ val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Double], 0.0)) {
Hi, thanks for your contribution! I have a question about considering weight info in findSplitsForContinuousFeature here. It looks like the continuous features will be influenced much more by instance weight, because the weight is considered twice: (1) when making splits and (2) when calculating impurity. Normally, weight is only used in the impurity calculation, according to the limited papers I have read. Could you provide some reference for the approach you take here? And correct me if I misunderstand your code. :) Thanks!
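For concreteness, here is a sketch of the weighted variant of the fold shown in the diff above, assuming featureSamples now yields (weight, value) pairs (the element type is an assumption):

```scala
// Sketch: accumulate the total weight per distinct value instead of a count of 1,
// so split candidates reflect weighted mass rather than unweighted counts.
def weightedValueCounts(
    featureSamples: Iterable[(Double, Double)]): (Map[Double, Double], Double) =
  featureSamples.foldLeft((Map.empty[Double, Double], 0.0)) {
    case ((counts, total), (weight, value)) =>
      (counts + ((value, counts.getOrElse(value, 0.0) + weight)), total + weight)
  }
```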
@sethah Hi, thanks for working on this. Do you have any plans to get it into 2.3?

Hi @sethah, is this something you're still working on, or would it be OK for someone else to take over the PR, with you perhaps doing the review side together?

ping @sethah

Yes, feel free to take this over.

Super excited to have you take this over, @imatiach-msft. Let me know if I can be of help.

Is this being taken over by #21632?

@HyukjinKwon yes, I've updated this PR in #21632

Closing in favor of #21632
This is updated PR #16722 to latest master

## What changes were proposed in this pull request?

This patch adds support for sample weights to DecisionTreeRegressor and DecisionTreeClassifier.

Note: This patch does not add support for sample weights to RandomForest. As discussed in the JIRA, we would like to add sample weights into the bagging process. This patch is large enough as is, and there are some additional considerations to be made for random forests. Since the machinery introduced here needs to be present regardless, I have opted to leave random forests for a follow-up PR.

## How was this patch tested?

The algorithms are tested to ensure that:
1. Arbitrary scaling of constant weights has no effect
2. Outliers with small weights do not affect the learned model
3. Oversampling and weighting are equivalent

Unit tests are also added to test other smaller components.

## Summary of changes

- Impurity aggregators now store weighted sufficient statistics. They also store a raw count, however, since this is needed to use minInstancesPerNode.
- Impurity aggregators now also hold the raw count.
- This patch maintains the meaning of minInstancesPerNode, in that the parameter still corresponds to raw, unweighted counts. It also adds a new parameter minWeightFractionPerNode which requires that nodes must contain at least minWeightFractionPerNode * weightedNumExamples total weight.
- This patch modifies findSplitsForContinuousFeatures to use weighted sums. Unit tests are added.
- TreePoint is modified to hold a sample weight.
- BaggedPoint is modified from:

```scala
private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) extends Serializable
```

to

```scala
private[spark] class BaggedPoint[Datum](
    val datum: Datum,
    val subsampleCounts: Array[Int],
    val sampleWeight: Double) extends Serializable
```

We do not simply multiply the counts by the weight and store that because we need the raw counts and the weight in order to use both minInstancesPerNode and minWeightPerNode.

**Note**: many of the changed files are due simply to using Instance instead of LabeledPoint.

Closes #21632 from imatiach-msft/ilmat/sample-weights.

Authored-by: Ilya Matiach <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
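Below is a minimal usage sketch of the feature described above. The setter names follow the params mentioned in the summary (weightCol, minWeightFractionPerNode); the DataFrame `training` and its column names are assumptions for illustration, and exact API details should be checked against the final merged code.

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Sketch: train a decision tree using a per-row sample weight column.
// Assumes a DataFrame `training` with "label", "features", and "weight" columns.
val dt = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setWeightCol("weight")              // sample weights, added by this PR
  .setMinWeightFractionPerNode(0.05)   // new param: minimum fraction of total weight per node

val model = dt.fit(training)
```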