Skip to content

Conversation

@rotationsymmetry
Copy link
Contributor

This PR adds weight support to the following Predictors in ML.

DecisionTreeClassifier
DecisionTreeRegressor
RandomForestClassifier
RandomForestRegressor

cc @jkbradley

@SparkQA
Copy link

SparkQA commented Oct 7, 2015

Test build #43318 has finished for PR 9008 at commit 367a443.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should specify return type here.

Is the reason that you can't just modify buildMetadata to accept and RDD[WeightedLabeledPoint] because you are trying not to change MLlib implementation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you very much for your comment.

  1. I will add the return type in my next push.

  2. yes, you are right, I don't want to change the mllib impl yet. I will leave it as a TODO after we have a standard way to represent weighted label point.

@SparkQA
Copy link

SparkQA commented Oct 9, 2015

Test build #43460 has finished for PR 9008 at commit 0ffbdd0.

  • This patch fails MiMa tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Oct 9, 2015

Test build #43482 timed out for PR 9008 at commit 33982fb after a configured wait of 250m.

@SparkQA
Copy link

SparkQA commented Oct 10, 2015

Test build #43503 has finished for PR 9008 at commit 3273ed4.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class ChildProcAppHandle implements SparkAppHandle
    • abstract class LauncherConnection implements Closeable, Runnable
    • final class LauncherProtocol
    • static class Message implements Serializable
    • static class Hello extends Message
    • static class SetAppId extends Message
    • static class SetState extends Message
    • static class Stop extends Message
    • class LauncherServer implements Closeable
    • class NamedThreadFactory implements ThreadFactory
    • class OutputRedirector

@SparkQA
Copy link

SparkQA commented Oct 10, 2015

Test build #43531 timed out for PR 9008 at commit 8f35057 after a configured wait of 250m.

@SparkQA
Copy link

SparkQA commented Oct 12, 2015

Test build #43553 has finished for PR 9008 at commit bd316d6.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@rotationsymmetry
Copy link
Contributor Author

Jenkins failed tests unrelated to this patch. Let's try again.

@SparkQA
Copy link

SparkQA commented Oct 12, 2015

Test build #43572 has finished for PR 9008 at commit 822382e.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Oct 13, 2015

Test build #43588 has finished for PR 9008 at commit c1785a8.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@rotationsymmetry
Copy link
Contributor Author

@sethah I have incorporated your comments in the latest patch. Thank you!

@jkbradley Do you have any comments or suggestions? Much appreciated.

@SparkQA
Copy link

SparkQA commented Nov 6, 2015

Test build #45206 has finished for PR 9008 at commit c1785a8.

  • This patch fails Spark unit tests.
  • This patch does not merge cleanly.
  • This patch adds no public classes.

@rotationsymmetry
Copy link
Contributor Author

retest this please

@SparkQA
Copy link

SparkQA commented Nov 8, 2015

Test build #45315 has finished for PR 9008 at commit 32f4548.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):\n * abstract class Writer extends BaseReadWrite\n * trait Writable\n * abstract class Reader[T] extends BaseReadWrite\n * trait Readable[T]\n * case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)\n * case class Expand(\n

final class RandomForestClassifier(override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestParams with TreeClassifierParams {
with RandomForestParams with TreeClassifierParams with HasWeightCol{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Space after HasWeightCol

@fabboe
Copy link

fabboe commented Feb 16, 2016

Thanks for working on this!

Minor: PR title says class weights but actually it's sample weights what is implemented.

.setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

val model1 = pipeline.fit(dataset)
val model2 = pipeline.fit(dataset, rf.weightCol->"weight")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto: space around ->

/**
* Inject the sample weight to sub-sample weights of the baggedPoints
*/
private[impl] def reweightSubSampleWeights(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a TODO in BaggedPoint.scala for accepting weighted instances. This might be a good time to address that. If not, we will have to implement this in this JIRA, fix Bagged Point in another JIRA, and then return to this, likely in a third JIRA. Thoughts?

@sethah
Copy link
Contributor

sethah commented Feb 16, 2016

@rotationsymmetry I made a pass on this, mostly minor comments. Thanks for working on this, it would be great to get it merged in!

@rotationsymmetry
Copy link
Contributor Author

@sethah Thank you very much for your review. I will incorporate the changes in the next few days. Regarding the TODO in BaggedPoint.scala, I want to look into the details to find out the scope of the change.

@sethah
Copy link
Contributor

sethah commented Feb 23, 2016

I noticed a problem with the current implementation regarding the minInstancesPerNode parameter. The number of instances in each node is now a weighted count where the weights can have an arbitrary scale. For example, a tree built with uniform weights where each weight is equal to 1.0 will build a different tree than uniform weights where each weight is 1.0 / N (N is number of samples). I suppose there are a number of ways to mitigate this.

I checked scikit-learn and they track the actual raw sample counts (unweighted) as well as the sample weights. They use min_samples_leaf to compute validity based on raw counts, and min_weight_fraction_leaf to compute validity based on weighted counts. This will not be possible under the current implementation here because we lose the raw counts when we convert to unadjustedBaggedInput to baggedInput. We could compare weighted split counts vs minInstancesPerNode / N where N is number of training samples, or we could adjust the BaggedPoint class to store counts and weight and proceed ala scikit-learn. I'm not sure what is best, thoughts?

@sethah
Copy link
Contributor

sethah commented Feb 23, 2016

Another issue is that the information gain for candidate splits is not computed correctly with fractional samples. This is because the information gain calculation here uses the sample counts which are converted to Long type. This produces incorrect results in general, and NaN values when the total count is less than 1. The count function here should return a Double type instead. Can we add a test to ensure that the trees are invariant under constant multiplication of the weights?

@sethah
Copy link
Contributor

sethah commented Mar 8, 2016

@rotationsymmetry: Will you have time to work on this? I am more than happy to send a PR to your PR if you do not have time.

@jkbradley @dbtsai Would you mind chiming in on the issue mentioned above about minimum instances per node?

@sethah
Copy link
Contributor

sethah commented Mar 18, 2016

cc @MLnick thoughts on the above comments?

@holdenk
Copy link
Contributor

holdenk commented Apr 12, 2016

@sethah So to avoid adding any overhead from computing stats for both these params one option would be to selectively compute only the stats that are required (e.g. if they request minInstancesPerNode per node request that and if they requeust min_weight_fraction_leaf compute the stats needed).

@sethah
Copy link
Contributor

sethah commented Apr 14, 2016

@holdenk Thanks for the feedback. Upon some further thought, I think that a.) We need to compute the statistics needed for both minInstancesPerNode and minWeightFractionPerNode and b.) it will not be too hard to compute them both and will not add a ton of extra memory overhead. Selectively computing one or the other could get complicated very quickly.

I am going to have a PR for this ready soon, which will incorporate changes submitted in this PR. I created two JIRAs for issues that I encountered when preparing this PR and submitted patches for each. They are:

@sethah
Copy link
Contributor

sethah commented Oct 12, 2016

@rotationsymmetry Could you please close this?

srowen added a commit to srowen/spark that referenced this pull request Oct 31, 2016
@asfgit asfgit closed this in 26b07f1 Oct 31, 2016
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants