Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 79 additions & 30 deletions mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,101 @@

package org.apache.spark.mllib.stat

import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD

private[stat] object KernelDensity {
/**
* :: Experimental ::
* Kernel density estimation. Given a sample from a population, estimate its probability density
* function at each of the given evaluation points using kernels. Only Gaussian kernel is supported.
*
* Scala example:
*
* {{{
* val sample = sc.parallelize(Seq(0.0, 1.0, 4.0, 4.0))
* val kd = new KernelDensity()
* .setSample(sample)
* .setBandwidth(3.0)
* val densities = kd.estimate(Array(-1.0, 2.0, 5.0))
* }}}
*/
@Experimental
class KernelDensity extends Serializable {

import KernelDensity._

/** Bandwidth of the kernel function. */
private var bandwidth: Double = 1.0

/** A sample from a population. */
private var sample: RDD[Double] = _

/**
* Given a set of samples from a distribution, estimates its density at the set of given points.
* Uses a Gaussian kernel with the given standard deviation.
* Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`).
*/
def estimate(samples: RDD[Double], standardDeviation: Double,
evaluationPoints: Array[Double]): Array[Double] = {
if (standardDeviation <= 0.0) {
throw new IllegalArgumentException("Standard deviation must be positive")
}
def setBandwidth(bandwidth: Double): this.type = {
require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.")
this.bandwidth = bandwidth
this
}

// This gets used in each Gaussian PDF computation, so compute it up front
val logStandardDeviationPlusHalfLog2Pi =
math.log(standardDeviation) + 0.5 * math.log(2 * math.Pi)
/**
* Sets the sample to use for density estimation.
*/
def setSample(sample: RDD[Double]): this.type = {
this.sample = sample
this
}

/**
* Sets the sample to use for density estimation (for Java users).
*/
def setSample(sample: JavaRDD[java.lang.Double]): this.type = {
this.sample = sample.rdd.asInstanceOf[RDD[Double]]
this
}

/**
* Estimates probability density function at the given array of points.
*/
def estimate(points: Array[Double]): Array[Double] = {
val sample = this.sample
val bandwidth = this.bandwidth

require(sample != null, "Must set sample before calling estimate.")

val (points, count) = samples.aggregate((new Array[Double](evaluationPoints.length), 0))(
val n = points.length
// This gets used in each Gaussian PDF computation, so compute it up front
val logStandardDeviationPlusHalfLog2Pi = math.log(bandwidth) + 0.5 * math.log(2 * math.Pi)
val (densities, count) = sample.aggregate((new Array[Double](n), 0L))(
(x, y) => {
var i = 0
while (i < evaluationPoints.length) {
x._1(i) += normPdf(y, standardDeviation, logStandardDeviationPlusHalfLog2Pi,
evaluationPoints(i))
while (i < n) {
x._1(i) += normPdf(y, bandwidth, logStandardDeviationPlusHalfLog2Pi, points(i))
i += 1
}
(x._1, i)
(x._1, n)
},
(x, y) => {
var i = 0
while (i < evaluationPoints.length) {
x._1(i) += y._1(i)
i += 1
}
blas.daxpy(n, 1.0, y._1, 1, x._1, 1)
(x._1, x._2 + y._2)
})

var i = 0
while (i < points.length) {
points(i) /= count
i += 1
}
points
blas.dscal(n, 1.0 / count, densities, 1)
densities
}
}

private object KernelDensity {

private def normPdf(mean: Double, standardDeviation: Double,
logStandardDeviationPlusHalfLog2Pi: Double, x: Double): Double = {
/** Evaluates the PDF of a normal distribution. */
def normPdf(
mean: Double,
standardDeviation: Double,
logStandardDeviationPlusHalfLog2Pi: Double,
x: Double): Double = {
val x0 = x - mean
val x1 = x0 / standardDeviation
val logDensity = -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi
Expand Down
14 changes: 0 additions & 14 deletions mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,4 @@ object Statistics {
def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = {
ChiSqTest.chiSquaredFeatures(data)
}

/**
* Given an empirical distribution defined by the input RDD of samples, estimate its density at
* each of the given evaluation points using a Gaussian kernel.
*
* @param samples The samples RDD used to define the empirical distribution.
* @param standardDeviation The standard deviation of the kernel Gaussians.
* @param evaluationPoints The points at which to estimate densities.
* @return An array the same size as evaluationPoints with the density at each point.
*/
def kernelDensity(samples: RDD[Double], standardDeviation: Double,
evaluationPoints: Iterable[Double]): Array[Double] = {
KernelDensity.estimate(samples, standardDeviation, evaluationPoints.toArray)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@

package org.apache.spark.mllib.stat

import org.scalatest.FunSuite

import org.apache.commons.math3.distribution.NormalDistribution
import org.scalatest.FunSuite

import org.apache.spark.mllib.util.MLlibTestSparkContext

class KernelDensitySuite extends FunSuite with MLlibTestSparkContext {
test("kernel density single sample") {
val rdd = sc.parallelize(Array(5.0))
val evaluationPoints = Array(5.0, 6.0)
val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints)
val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints)
val normal = new NormalDistribution(5.0, 3.0)
val acceptableErr = 1e-6
assert(densities(0) - normal.density(5.0) < acceptableErr)
Expand All @@ -37,7 +36,7 @@ class KernelDensitySuite extends FunSuite with MLlibTestSparkContext {
test("kernel density multiple samples") {
val rdd = sc.parallelize(Array(5.0, 10.0))
val evaluationPoints = Array(5.0, 6.0)
val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints)
val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints)
val normal1 = new NormalDistribution(5.0, 3.0)
val normal2 = new NormalDistribution(10.0, 3.0)
val acceptableErr = 1e-6
Expand Down