diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
index 4e372702f0c65..cad73fe6cf9f5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -33,13 +33,20 @@ import org.apache.spark.util.random.XORShiftRandom
  * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
  *
  * @param datum Data instance
- * @param subsampleWeights Weight of this instance in each subsampled dataset.
- *
- * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted
- * dataset support, update. (We store subsampleWeights as Double for this future extension.)
+ * @param subsampleCounts Number of samples of this instance in each subsampled dataset.
+ * @param sampleWeight The weight of this instance.
  */
-private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
-  extends Serializable
+private[spark] class BaggedPoint[Datum](
+    val datum: Datum,
+    val subsampleCounts: Array[Int],
+    val sampleWeight: Double) extends Serializable {
+
+  /**
+   * Subsample counts weighted by the sample weight.
+   */
+  def weightedCounts: Array[Double] = subsampleCounts.map(_ * sampleWeight)
+
+}
 
 private[spark] object BaggedPoint {
 
@@ -52,6 +59,7 @@ private[spark] object BaggedPoint {
    * @param subsamplingRate Fraction of the training data used for learning decision tree.
    * @param numSubsamples Number of subsamples of this RDD to take.
    * @param withReplacement Sampling with/without replacement.
+   * @param extractSampleWeight A function to get the sample weight of each point.
    * @param seed Random seed.
    * @return BaggedPoint dataset representation.
   */
@@ -60,20 +68,24 @@ private[spark] object BaggedPoint {
       subsamplingRate: Double,
       numSubsamples: Int,
       withReplacement: Boolean,
+      extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0,
       seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
     if (withReplacement) {
-      convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
+      convertToBaggedRDDSamplingWithReplacement(input, extractSampleWeight, subsamplingRate,
+        numSubsamples, seed)
     } else {
       if (numSubsamples == 1 && subsamplingRate == 1.0) {
-        convertToBaggedRDDWithoutSampling(input)
+        convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
       } else {
-        convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
+        convertToBaggedRDDSamplingWithoutReplacement(input, extractSampleWeight, subsamplingRate,
+          numSubsamples, seed)
       }
     }
   }
 
   private def convertToBaggedRDDSamplingWithoutReplacement[Datum] (
       input: RDD[Datum],
+      extractSampleWeight: (Datum => Double),
       subsamplingRate: Double,
       numSubsamples: Int,
       seed: Long): RDD[BaggedPoint[Datum]] = {
@@ -82,22 +94,23 @@ private[spark] object BaggedPoint {
       val rng = new XORShiftRandom
       rng.setSeed(seed + partitionIndex + 1)
       instances.map { instance =>
-        val subsampleWeights = new Array[Double](numSubsamples)
+        val subsampleCounts = new Array[Int](numSubsamples)
         var subsampleIndex = 0
         while (subsampleIndex < numSubsamples) {
           val x = rng.nextDouble()
-          subsampleWeights(subsampleIndex) = {
-            if (x < subsamplingRate) 1.0 else 0.0
+          subsampleCounts(subsampleIndex) = {
+            if (x < subsamplingRate) 1 else 0
           }
           subsampleIndex += 1
         }
-        new BaggedPoint(instance, subsampleWeights)
+        new BaggedPoint(instance, subsampleCounts, extractSampleWeight(instance))
       }
     }
   }
 
   private def convertToBaggedRDDSamplingWithReplacement[Datum] (
       input: RDD[Datum],
+      extractSampleWeight: (Datum => Double),
       subsample: Double,
       numSubsamples: Int,
       seed: Long): RDD[BaggedPoint[Datum]] = {
@@ -106,20 +119,20 @@ private[spark] object BaggedPoint {
       val poisson = new PoissonDistribution(subsample)
       poisson.reseedRandomGenerator(seed + partitionIndex + 1)
       instances.map { instance =>
-        val subsampleWeights = new Array[Double](numSubsamples)
+        val subsampleCounts = new Array[Int](numSubsamples)
         var subsampleIndex = 0
         while (subsampleIndex < numSubsamples) {
-          subsampleWeights(subsampleIndex) = poisson.sample()
+          subsampleCounts(subsampleIndex) = poisson.sample()
           subsampleIndex += 1
         }
-        new BaggedPoint(instance, subsampleWeights)
+        new BaggedPoint(instance, subsampleCounts, extractSampleWeight(instance))
       }
     }
   }
 
   private def convertToBaggedRDDWithoutSampling[Datum] (
-      input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
-    input.map(datum => new BaggedPoint(datum, Array(1.0)))
+      input: RDD[Datum],
+      extractSampleWeight: (Datum => Double)): RDD[BaggedPoint[Datum]] = {
+    input.map(datum => new BaggedPoint(datum, Array(1), extractSampleWeight(datum)))
   }
-
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 7b1fd089f2943..d0cb880625fde 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -125,7 +125,8 @@ private[spark] object RandomForest extends Logging {
     val withReplacement = numTrees > 1
 
     val baggedInput = BaggedPoint
-      .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
+      .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement,
+        seed = seed)
       .persist(StorageLevel.MEMORY_AND_DISK)
 
     // depth of the decision tree
@@ -407,7 +408,7 @@ private[spark] object RandomForest extends Logging {
       if (nodeInfo != null) {
         val aggNodeIndex = nodeInfo.nodeIndexInGroup
         val featuresForNode = nodeInfo.featureSubset
-        val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+        val instanceWeight = baggedPoint.subsampleCounts(treeIndex)
         if (metadata.unorderedFeatures.isEmpty) {
           orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
         } else {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
index 77ab3d8bb75f7..ec20148828afc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
@@ -29,9 +29,9 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
   test("BaggedPoint RDD: without subsampling") {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
     val rdd = sc.parallelize(arr)
-    val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42)
+    val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, seed = 42)
     baggedRDD.collect().foreach { baggedPoint =>
-      assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
+      assert(baggedPoint.subsampleCounts.size == 1 && baggedPoint.subsampleCounts(0) == 1)
     }
   }
 
@@ -43,8 +43,9 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
     val rdd = sc.parallelize(arr)
     seeds.foreach { seed =>
-      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed)
-      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed = seed)
+      val subsampleCounts: Array[Array[Double]] =
+        baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
       EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
         expectedStddev, epsilon = 0.01)
     }
@@ -59,8 +60,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
     val rdd = sc.parallelize(arr)
     seeds.foreach { seed =>
-      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed)
-      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+      val baggedRDD =
+        BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed = seed)
+      val subsampleCounts: Array[Array[Double]] =
+        baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
       EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
         expectedStddev, epsilon = 0.01)
     }
@@ -74,8 +77,9 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
     val rdd = sc.parallelize(arr)
     seeds.foreach { seed =>
-      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed)
-      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed = seed)
+      val subsampleCounts: Array[Array[Double]] =
+        baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
       EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
         expectedStddev, epsilon = 0.01)
     }
@@ -90,8 +94,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
     val rdd = sc.parallelize(arr)
     seeds.foreach { seed =>
-      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed)
-      val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+      val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false,
+        seed = seed)
+      val subsampleCounts: Array[Array[Double]] =
+        baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
       EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
         expectedStddev, epsilon = 0.01)
     }
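
The core of the change above is that each `BaggedPoint` now stores integer `subsampleCounts` (one per subsample) plus a single `sampleWeight`, and `weightedCounts` multiplies the two. Below is a minimal standalone sketch (plain Scala, no Spark dependency) of how those fields combine; the `BaggedPointSketch` name and the example values are illustrative only and are not part of the patch:

```scala
// Standalone illustration of the new BaggedPoint fields; the real class in the
// patch is private[spark] and generic over the RDD element type.
class BaggedPointSketch[Datum](
    val datum: Datum,
    val subsampleCounts: Array[Int], // copies of this instance in each subsample
    val sampleWeight: Double) {      // user-supplied instance weight

  // Same expression as in the patch: per-subsample count scaled by the instance weight.
  def weightedCounts: Array[Double] = subsampleCounts.map(_ * sampleWeight)
}

object BaggedPointSketchExample extends App {
  // The example from the class doc: 1, 0 and 4 copies in three subsamples,
  // combined here with a hypothetical sample weight of 0.5.
  val point = new BaggedPointSketch("someDatum", Array(1, 0, 4), 0.5)
  println(point.weightedCounts.mkString(", ")) // prints: 0.5, 0.0, 2.0
}
```

Because `extractSampleWeight` defaults to `(_: Datum) => 1.0`, existing callers such as the `RandomForest` and `BaggedPointSuite` call sites above keep their previous behaviour, with every `sampleWeight` equal to 1.0.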