-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-14599][ML] BaggedPoint should support sample weights. #12370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,13 +33,20 @@ import org.apache.spark.util.random.XORShiftRandom | |
| * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. | ||
| * | ||
| * @param datum Data instance | ||
| * @param subsampleWeights Weight of this instance in each subsampled dataset. | ||
| * | ||
| * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted | ||
| * dataset support, update. (We store subsampleWeights as Double for this future extension.) | ||
| * @param subsampleCounts Number of samples of this instance in each subsampled dataset. | ||
| * @param sampleWeight The weight of this instance. | ||
| */ | ||
| private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) | ||
| extends Serializable | ||
| private[spark] class BaggedPoint[Datum]( | ||
| val datum: Datum, | ||
| val subsampleCounts: Array[Int], | ||
| val sampleWeight: Double) extends Serializable { | ||
|
|
||
| /** | ||
| * Subsample counts weighted by the sample weight. | ||
| */ | ||
| def weightedCounts: Array[Double] = subsampleCounts.map(_ * sampleWeight) | ||
|
|
||
| } | ||
|
|
||
| private[spark] object BaggedPoint { | ||
|
|
||
|
|
@@ -52,6 +59,7 @@ private[spark] object BaggedPoint { | |
| * @param subsamplingRate Fraction of the training data used for learning decision tree. | ||
| * @param numSubsamples Number of subsamples of this RDD to take. | ||
| * @param withReplacement Sampling with/without replacement. | ||
| * @param extractSampleWeight A function to get the sample weight of each point. | ||
| * @param seed Random seed. | ||
| * @return BaggedPoint dataset representation. | ||
| */ | ||
|
|
@@ -60,20 +68,24 @@ private[spark] object BaggedPoint { | |
| subsamplingRate: Double, | ||
| numSubsamples: Int, | ||
| withReplacement: Boolean, | ||
| extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just checking my understanding here, but is the intention to support, in the future, something like
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that is exactly the case for this. I could not think of a better way to implement this, while still keeping BaggedPoint generic (i.e. not requiring Datum to have a weight property or something similar). |
||
| seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = { | ||
| if (withReplacement) { | ||
| convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) | ||
| convertToBaggedRDDSamplingWithReplacement(input, extractSampleWeight, subsamplingRate, | ||
| numSubsamples, seed) | ||
| } else { | ||
| if (numSubsamples == 1 && subsamplingRate == 1.0) { | ||
| convertToBaggedRDDWithoutSampling(input) | ||
| convertToBaggedRDDWithoutSampling(input, extractSampleWeight) | ||
| } else { | ||
| convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) | ||
| convertToBaggedRDDSamplingWithoutReplacement(input, extractSampleWeight, subsamplingRate, | ||
| numSubsamples, seed) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private def convertToBaggedRDDSamplingWithoutReplacement[Datum] ( | ||
| input: RDD[Datum], | ||
| extractSampleWeight: (Datum => Double), | ||
| subsamplingRate: Double, | ||
| numSubsamples: Int, | ||
| seed: Long): RDD[BaggedPoint[Datum]] = { | ||
|
|
@@ -82,22 +94,23 @@ private[spark] object BaggedPoint { | |
| val rng = new XORShiftRandom | ||
| rng.setSeed(seed + partitionIndex + 1) | ||
| instances.map { instance => | ||
| val subsampleWeights = new Array[Double](numSubsamples) | ||
| val subsampleCounts = new Array[Int](numSubsamples) | ||
| var subsampleIndex = 0 | ||
| while (subsampleIndex < numSubsamples) { | ||
| val x = rng.nextDouble() | ||
| subsampleWeights(subsampleIndex) = { | ||
| if (x < subsamplingRate) 1.0 else 0.0 | ||
| subsampleCounts(subsampleIndex) = { | ||
| if (x < subsamplingRate) 1 else 0 | ||
| } | ||
| subsampleIndex += 1 | ||
| } | ||
| new BaggedPoint(instance, subsampleWeights) | ||
| new BaggedPoint(instance, subsampleCounts, extractSampleWeight(instance)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private def convertToBaggedRDDSamplingWithReplacement[Datum] ( | ||
| input: RDD[Datum], | ||
| extractSampleWeight: (Datum => Double), | ||
| subsample: Double, | ||
| numSubsamples: Int, | ||
| seed: Long): RDD[BaggedPoint[Datum]] = { | ||
|
|
@@ -106,20 +119,20 @@ private[spark] object BaggedPoint { | |
| val poisson = new PoissonDistribution(subsample) | ||
| poisson.reseedRandomGenerator(seed + partitionIndex + 1) | ||
| instances.map { instance => | ||
| val subsampleWeights = new Array[Double](numSubsamples) | ||
| val subsampleCounts = new Array[Int](numSubsamples) | ||
| var subsampleIndex = 0 | ||
| while (subsampleIndex < numSubsamples) { | ||
| subsampleWeights(subsampleIndex) = poisson.sample() | ||
| subsampleCounts(subsampleIndex) = poisson.sample() | ||
| subsampleIndex += 1 | ||
| } | ||
| new BaggedPoint(instance, subsampleWeights) | ||
| new BaggedPoint(instance, subsampleCounts, extractSampleWeight(instance)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private def convertToBaggedRDDWithoutSampling[Datum] ( | ||
| input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { | ||
| input.map(datum => new BaggedPoint(datum, Array(1.0))) | ||
| input: RDD[Datum], | ||
| extractSampleWeight: (Datum => Double)): RDD[BaggedPoint[Datum]] = { | ||
| input.map(datum => new BaggedPoint(datum, Array(1), extractSampleWeight(datum))) | ||
| } | ||
|
|
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be a `val`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added this as a convenience method. If we make it a val then we add storage overhead in the class which is redundant. If preferable, we could remove it entirely.