Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.mllib.stat

import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.stat

/**
* :: DeveloperApi ::
Expand All @@ -39,7 +40,14 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
*/
@Since("1.1.0")
@DeveloperApi
class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
class MultivariateOnlineSummarizer(mask: Int)
extends MultivariateStatisticalSummary with Serializable {

import MultivariateOnlineSummarizer._
def this() = {
this(MultivariateOnlineSummarizer.allMask)
}
private def testMask(m: Int): Boolean = (mask & m) != 0

private var n = 0
private var currMean: Array[Double] = _
Expand Down Expand Up @@ -71,14 +79,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(instance.size > 0, s"Vector should have dimension larger than zero.")
n = instance.size

currMean = Array.ofDim[Double](n)
currM2n = Array.ofDim[Double](n)
currM2 = Array.ofDim[Double](n)
currL1 = Array.ofDim[Double](n)
weightSum = Array.ofDim[Double](n)
nnz = Array.ofDim[Long](n)
currMax = Array.fill[Double](n)(Double.MinValue)
currMin = Array.fill[Double](n)(Double.MaxValue)
if(testMask(currMeanMask)) currMean = Array.ofDim[Double](n)
if(testMask(currM2nMask)) currM2n = Array.ofDim[Double](n)
if(testMask(currM2Mask)) currM2 = Array.ofDim[Double](n)
if(testMask(currL1Mask)) currL1 = Array.ofDim[Double](n)
if(testMask(weightSumMask)) weightSum = Array.ofDim[Double](n)
if(testMask(nnzMask)) nnz = Array.ofDim[Long](n)
if(testMask(currMaxMask)) currMax = Array.fill[Double](n)(Double.MinValue)
if(testMask(currMinMask)) currMin = Array.fill[Double](n)(Double.MaxValue)
}

require(n == instance.size, s"Dimensions mismatch when adding new sample." +
Expand All @@ -94,22 +102,26 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val localCurrMin = currMin
instance.foreachActive { (index, value) =>
if (value != 0.0) {
if (localCurrMax(index) < value) {
if (testMask(currMaxMask) && localCurrMax(index) < value) {
localCurrMax(index) = value
}
if (localCurrMin(index) > value) {
if (testMask(currMinMask) && localCurrMin(index) > value) {
localCurrMin(index) = value
}

val prevMean = localCurrMean(index)
val diff = value - prevMean
localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight)
localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff
localCurrM2(index) += weight * value * value
localCurrL1(index) += weight * math.abs(value)
if (testMask(currMeanMask)) {
val prevMean = localCurrMean(index)
val diff = value - prevMean
localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight)
if (testMask(currM2nMask)) {
localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff
}
}
if(testMask(currM2Mask)) localCurrM2(index) += weight * value * value
if(testMask(currL1Mask)) localCurrL1(index) += weight * math.abs(value)

localWeightSum(index) += weight
localNumNonzeros(index) += 1
if(testMask(weightSumMask)) localWeightSum(index) += weight
if(testMask(nnzMask)) localNumNonzeros(index) += 1
}
}

Expand All @@ -136,41 +148,52 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
weightSquareSum += other.weightSquareSum
var i = 0
while (i < n) {
val thisNnz = weightSum(i)
val otherNnz = other.weightSum(i)
val totalNnz = thisNnz + otherNnz
val totalCnnz = nnz(i) + other.nnz(i)
if (totalNnz != 0.0) {
val deltaMean = other.currMean(i) - currMean(i)
// merge mean together
currMean(i) += deltaMean * otherNnz / totalNnz
// merge m2n together
currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
// merge m2 together
currM2(i) += other.currM2(i)
// merge l1 together
currL1(i) += other.currL1(i)
// merge max and min
currMax(i) = math.max(currMax(i), other.currMax(i))
currMin(i) = math.min(currMin(i), other.currMin(i))
if (testMask(weightSumMask)) {
val thisWeightSum = weightSum(i)
val otherWeightSum = other.weightSum(i)
val totalWeightSum = thisWeightSum + otherWeightSum
if (totalWeightSum != 0.0) {
if (testMask(currMeanMask)) {
val deltaMean = other.currMean(i) - currMean(i)
// merge mean together
currMean(i) += deltaMean * otherWeightSum / totalWeightSum
// merge m2n together
if (testMask(currM2nMask)) {
currM2n(i) += other.currM2n(i) +
deltaMean * deltaMean * thisWeightSum * otherWeightSum / totalWeightSum
}
}
// merge m2 together
if(testMask(currM2Mask)) currM2(i) += other.currM2(i)
// merge l1 together
if(testMask(currL1Mask)) currL1(i) += other.currL1(i)
}
weightSum(i) = totalWeightSum
}
if (testMask(nnzMask)) {
val totalNnz = nnz(i) + other.nnz(i)
if (totalNnz != 0) {
// merge max and min
if (testMask(currMaxMask)) currMax(i) = math.max(currMax(i), other.currMax(i))
if (testMask(currMinMask)) currMin(i) = math.min(currMin(i), other.currMin(i))
}
nnz(i) = totalNnz
}
weightSum(i) = totalNnz
nnz(i) = totalCnnz
i += 1
}
} else if (totalWeightSum == 0.0 && other.totalWeightSum != 0.0) {
this.n = other.n
this.currMean = other.currMean.clone()
this.currM2n = other.currM2n.clone()
this.currM2 = other.currM2.clone()
this.currL1 = other.currL1.clone()
if (testMask(currMeanMask)) this.currMean = other.currMean.clone()
if (testMask(currM2nMask)) this.currM2n = other.currM2n.clone()
if (testMask(currM2Mask)) this.currM2 = other.currM2.clone()
if (testMask(currL1Mask)) this.currL1 = other.currL1.clone()
this.totalCnt = other.totalCnt
this.totalWeightSum = other.totalWeightSum
this.weightSquareSum = other.weightSquareSum
this.weightSum = other.weightSum.clone()
this.nnz = other.nnz.clone()
this.currMax = other.currMax.clone()
this.currMin = other.currMin.clone()
if (testMask(weightSumMask)) this.weightSum = other.weightSum.clone()
if (testMask(nnzMask)) this.nnz = other.nnz.clone()
if (testMask(currMaxMask)) this.currMax = other.currMax.clone()
if (testMask(currMinMask)) this.currMin = other.currMin.clone()
}
this
}
Expand Down Expand Up @@ -298,3 +321,23 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
Vectors.dense(currL1)
}
}

object MultivariateOnlineSummarizer {

private val currMeanMask = 0x1
private val currM2nMask = 0x2
private val currM2Mask = 0x4
private val currL1Mask = 0x8
private val weightSumMask = 0x10
private val nnzMask = 0x20
private val currMaxMask = 0x40
private val currMinMask = 0x80

val meanMask = currMeanMask | weightSumMask
val varianceMask = currMeanMask | currM2nMask | weightSumMask
val numNonZerosMask = nnzMask
val maxMask = nnzMask | currMaxMask
val minMask = nnzMask | currMinMask

val allMask = 0xFFFFFFFF
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,34 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
}

test("mask test") {
import MultivariateOnlineSummarizer._

assert(new MultivariateOnlineSummarizer(meanMask)
.add(Vectors.dense(-1.0, 0.0, 6.0))
.add(Vectors.dense(3.0, -3.0, 0.0)).mean
~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch")

assert(new MultivariateOnlineSummarizer(minMask)
.add(Vectors.dense(-1.0, 0.0, 6.0))
.add(Vectors.dense(3.0, -3.0, 0.0)).min
~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch")

assert(new MultivariateOnlineSummarizer(maxMask)
.add(Vectors.dense(-1.0, 0.0, 6.0))
.add(Vectors.dense(3.0, -3.0, 0.0)).max
~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch")

assert(new MultivariateOnlineSummarizer(numNonZerosMask)
.add(Vectors.dense(-1.0, 0.0, 6.0))
.add(Vectors.dense(3.0, -3.0, 0.0)).numNonzeros
~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch")

assert(new MultivariateOnlineSummarizer(varianceMask)
.add(Vectors.dense(-1.0, 0.0, 6.0))
.add(Vectors.dense(3.0, -3.0, 0.0)).variance
~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch")
}

}