Commit 947ea1c

[SPARK-7753] [MLLIB] Update KernelDensity API
Update `KernelDensity` API to make it extensible to different kernels in the future. `bandwidth` is used instead of `standardDeviation`. The static `kernelDensity` method is removed from `Statistics`. The implementation is updated using BLAS, while the algorithm remains the same.

sryza srowen

Author: Xiangrui Meng <[email protected]>

Closes #6279 from mengxr/SPARK-7753 and squashes the following commits:

4cdfadc [Xiangrui Meng] add example code in the doc
767fd5a [Xiangrui Meng] update KernelDensity API
1 parent 8ddcb25 commit 947ea1c
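
For orientation, the commit replaces the static entry point with a small builder-style class. A minimal sketch of the new usage, assuming a live SparkContext `sc` (it mirrors the example added to the Scaladoc in the diff below):

val sample = sc.parallelize(Seq(0.0, 1.0, 4.0, 4.0))  // RDD[Double] drawn from the population
val kd = new KernelDensity()
  .setSample(sample)   // sample to estimate from
  .setBandwidth(3.0)   // standard deviation of the Gaussian kernel
val densities: Array[Double] = kd.estimate(Array(-1.0, 2.0, 5.0))  // PDF at each evaluation point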

File tree

3 files changed: +82 -48 lines changed


mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala

Lines changed: 79 additions & 30 deletions
@@ -17,52 +17,101 @@

 package org.apache.spark.mllib.stat

+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD

-private[stat] object KernelDensity {
+/**
+ * :: Experimental ::
+ * Kernel density estimation. Given a sample from a population, estimate its probability density
+ * function at each of the given evaluation points using kernels. Only Gaussian kernel is supported.
+ *
+ * Scala example:
+ *
+ * {{{
+ * val sample = sc.parallelize(Seq(0.0, 1.0, 4.0, 4.0))
+ * val kd = new KernelDensity()
+ *   .setSample(sample)
+ *   .setBandwidth(3.0)
+ * val densities = kd.estimate(Array(-1.0, 2.0, 5.0))
+ * }}}
+ */
+@Experimental
+class KernelDensity extends Serializable {
+
+  import KernelDensity._
+
+  /** Bandwidth of the kernel function. */
+  private var bandwidth: Double = 1.0
+
+  /** A sample from a population. */
+  private var sample: RDD[Double] = _
+
   /**
-   * Given a set of samples from a distribution, estimates its density at the set of given points.
-   * Uses a Gaussian kernel with the given standard deviation.
+   * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`).
    */
-  def estimate(samples: RDD[Double], standardDeviation: Double,
-      evaluationPoints: Array[Double]): Array[Double] = {
-    if (standardDeviation <= 0.0) {
-      throw new IllegalArgumentException("Standard deviation must be positive")
-    }
+  def setBandwidth(bandwidth: Double): this.type = {
+    require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.")
+    this.bandwidth = bandwidth
+    this
+  }

-    // This gets used in each Gaussian PDF computation, so compute it up front
-    val logStandardDeviationPlusHalfLog2Pi =
-      math.log(standardDeviation) + 0.5 * math.log(2 * math.Pi)
+  /**
+   * Sets the sample to use for density estimation.
+   */
+  def setSample(sample: RDD[Double]): this.type = {
+    this.sample = sample
+    this
+  }
+
+  /**
+   * Sets the sample to use for density estimation (for Java users).
+   */
+  def setSample(sample: JavaRDD[java.lang.Double]): this.type = {
+    this.sample = sample.rdd.asInstanceOf[RDD[Double]]
+    this
+  }
+
+  /**
+   * Estimates probability density function at the given array of points.
+   */
+  def estimate(points: Array[Double]): Array[Double] = {
+    val sample = this.sample
+    val bandwidth = this.bandwidth
+
+    require(sample != null, "Must set sample before calling estimate.")

-    val (points, count) = samples.aggregate((new Array[Double](evaluationPoints.length), 0))(
+    val n = points.length
+    // This gets used in each Gaussian PDF computation, so compute it up front
+    val logStandardDeviationPlusHalfLog2Pi = math.log(bandwidth) + 0.5 * math.log(2 * math.Pi)
+    val (densities, count) = sample.aggregate((new Array[Double](n), 0L))(
       (x, y) => {
         var i = 0
-        while (i < evaluationPoints.length) {
-          x._1(i) += normPdf(y, standardDeviation, logStandardDeviationPlusHalfLog2Pi,
-            evaluationPoints(i))
+        while (i < n) {
+          x._1(i) += normPdf(y, bandwidth, logStandardDeviationPlusHalfLog2Pi, points(i))
           i += 1
         }
-        (x._1, i)
+        (x._1, n)
       },
       (x, y) => {
-        var i = 0
-        while (i < evaluationPoints.length) {
-          x._1(i) += y._1(i)
-          i += 1
-        }
+        blas.daxpy(n, 1.0, y._1, 1, x._1, 1)
         (x._1, x._2 + y._2)
       })
-
-    var i = 0
-    while (i < points.length) {
-      points(i) /= count
-      i += 1
-    }
-    points
+    blas.dscal(n, 1.0 / count, densities, 1)
+    densities
   }
+}
+
+private object KernelDensity {

-  private def normPdf(mean: Double, standardDeviation: Double,
-      logStandardDeviationPlusHalfLog2Pi: Double, x: Double): Double = {
+  /** Evaluates the PDF of a normal distribution. */
+  def normPdf(
+      mean: Double,
+      standardDeviation: Double,
+      logStandardDeviationPlusHalfLog2Pi: Double,
+      x: Double): Double = {
     val x0 = x - mean
     val x1 = x0 / standardDeviation
     val logDensity = -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi
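
The hand-written merge and normalization loops above are replaced by BLAS calls via netlib-java. As an illustrative stand-alone sketch (not part of the commit), daxpy accumulates one partial-density array into another and dscal divides by the sample count:

import com.github.fommil.netlib.BLAS.{getInstance => blas}

val acc = Array(0.2, 0.4)             // running densities from one partition
val other = Array(0.1, 0.3)           // densities from another partition
blas.daxpy(2, 1.0, other, 1, acc, 1)  // acc := 1.0 * other + acc
blas.dscal(2, 1.0 / 4, acc, 1)        // acc := acc / count, here count = 4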

mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala

Lines changed: 0 additions & 14 deletions
@@ -149,18 +149,4 @@ object Statistics {
   def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = {
     ChiSqTest.chiSquaredFeatures(data)
   }
-
-  /**
-   * Given an empirical distribution defined by the input RDD of samples, estimate its density at
-   * each of the given evaluation points using a Gaussian kernel.
-   *
-   * @param samples The samples RDD used to define the empirical distribution.
-   * @param standardDeviation The standard deviation of the kernel Gaussians.
-   * @param evaluationPoints The points at which to estimate densities.
-   * @return An array the same size as evaluationPoints with the density at each point.
-   */
-  def kernelDensity(samples: RDD[Double], standardDeviation: Double,
-      evaluationPoints: Iterable[Double]): Array[Double] = {
-    KernelDensity.estimate(samples, standardDeviation, evaluationPoints.toArray)
-  }
 }
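
Callers of the removed `Statistics.kernelDensity` can switch to the new class directly; a minimal migration sketch (note the old signature took `Iterable[Double]` evaluation points, while `estimate` takes an `Array[Double]`):

// before: Statistics.kernelDensity(samples, standardDeviation, evaluationPoints)
// after:
val densities = new KernelDensity()
  .setSample(samples)
  .setBandwidth(standardDeviation)
  .estimate(evaluationPoints.toArray)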

mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala

Lines changed: 3 additions & 4 deletions
@@ -17,17 +17,16 @@

 package org.apache.spark.mllib.stat

-import org.scalatest.FunSuite
-
 import org.apache.commons.math3.distribution.NormalDistribution
+import org.scalatest.FunSuite

 import org.apache.spark.mllib.util.MLlibTestSparkContext

 class KernelDensitySuite extends FunSuite with MLlibTestSparkContext {
   test("kernel density single sample") {
     val rdd = sc.parallelize(Array(5.0))
     val evaluationPoints = Array(5.0, 6.0)
-    val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints)
+    val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints)
     val normal = new NormalDistribution(5.0, 3.0)
     val acceptableErr = 1e-6
     assert(densities(0) - normal.density(5.0) < acceptableErr)
@@ -37,7 +36,7 @@ class KernelDensitySuite extends FunSuite with MLlibTestSparkContext {
   test("kernel density multiple samples") {
     val rdd = sc.parallelize(Array(5.0, 10.0))
     val evaluationPoints = Array(5.0, 6.0)
-    val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints)
+    val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints)
     val normal1 = new NormalDistribution(5.0, 3.0)
     val normal2 = new NormalDistribution(10.0, 3.0)
     val acceptableErr = 1e-6
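
For reference, a Gaussian KDE evaluates to the average of the per-sample kernel densities, which is what these tests check against commons-math3. A small sketch of the expected values for the multi-sample case (assuming commons-math3 on the classpath, as in the suite):

import org.apache.commons.math3.distribution.NormalDistribution

// f(x) = (1/n) * sum_i NormalDistribution(x_i, bandwidth).density(x)
val sample = Seq(5.0, 10.0)
val kernels = sample.map(xi => new NormalDistribution(xi, 3.0))
val expectedAt5 = kernels.map(_.density(5.0)).sum / sample.size
val expectedAt6 = kernels.map(_.density(6.0)).sum / sample.size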
