From 302946821db75117b4ab2346b4b445472ed50eb4 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 31 Aug 2016 18:41:03 -0700 Subject: [PATCH] update v3 --- .../stat/MultivariateOnlineSummarizer.scala | 135 ++++++++++++------ .../MultivariateOnlineSummarizerSuite.scala | 30 ++++ 2 files changed, 119 insertions(+), 46 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 7a2a7a35a91c..186a5b585443 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.stat /** * :: DeveloperApi :: @@ -39,7 +40,14 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} */ @Since("1.1.0") @DeveloperApi -class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { +class MultivariateOnlineSummarizer(mask: Int) + extends MultivariateStatisticalSummary with Serializable { + + import MultivariateOnlineSummarizer._ + def this() = { + this(MultivariateOnlineSummarizer.allMask) + } + private def testMask(m: Int): Boolean = (mask & m) != 0 private var n = 0 private var currMean: Array[Double] = _ @@ -71,14 +79,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(instance.size > 0, s"Vector should have dimension larger than zero.") n = instance.size - currMean = Array.ofDim[Double](n) - currM2n = Array.ofDim[Double](n) - currM2 = Array.ofDim[Double](n) - currL1 = Array.ofDim[Double](n) - weightSum = Array.ofDim[Double](n) - nnz = Array.ofDim[Long](n) - currMax = Array.fill[Double](n)(Double.MinValue) - currMin = Array.fill[Double](n)(Double.MaxValue) + if(testMask(currMeanMask)) currMean = Array.ofDim[Double](n) + if(testMask(currM2nMask)) currM2n = Array.ofDim[Double](n) + if(testMask(currM2Mask)) currM2 = Array.ofDim[Double](n) + if(testMask(currL1Mask)) currL1 = Array.ofDim[Double](n) + if(testMask(weightSumMask)) weightSum = Array.ofDim[Double](n) + if(testMask(nnzMask)) nnz = Array.ofDim[Long](n) + if(testMask(currMaxMask)) currMax = Array.fill[Double](n)(Double.MinValue) + if(testMask(currMinMask)) currMin = Array.fill[Double](n)(Double.MaxValue) } require(n == instance.size, s"Dimensions mismatch when adding new sample." + @@ -94,22 +102,26 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val localCurrMin = currMin instance.foreachActive { (index, value) => if (value != 0.0) { - if (localCurrMax(index) < value) { + if (testMask(currMaxMask) && localCurrMax(index) < value) { localCurrMax(index) = value } - if (localCurrMin(index) > value) { + if (testMask(currMinMask) && localCurrMin(index) > value) { localCurrMin(index) = value } - val prevMean = localCurrMean(index) - val diff = value - prevMean - localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight) - localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff - localCurrM2(index) += weight * value * value - localCurrL1(index) += weight * math.abs(value) + if (testMask(currMeanMask)) { + val prevMean = localCurrMean(index) + val diff = value - prevMean + localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight) + if (testMask(currM2nMask)) { + localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff + } + } + if(testMask(currM2Mask)) localCurrM2(index) += weight * value * value + if(testMask(currL1Mask)) localCurrL1(index) += weight * math.abs(value) - localWeightSum(index) += weight - localNumNonzeros(index) += 1 + if(testMask(weightSumMask)) localWeightSum(index) += weight + if(testMask(nnzMask)) localNumNonzeros(index) += 1 } } @@ -136,41 +148,52 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S weightSquareSum += other.weightSquareSum var i = 0 while (i < n) { - val thisNnz = weightSum(i) - val otherNnz = other.weightSum(i) - val totalNnz = thisNnz + otherNnz - val totalCnnz = nnz(i) + other.nnz(i) - if (totalNnz != 0.0) { - val deltaMean = other.currMean(i) - currMean(i) - // merge mean together - currMean(i) += deltaMean * otherNnz / totalNnz - // merge m2n together - currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz - // merge m2 together - currM2(i) += other.currM2(i) - // merge l1 together - currL1(i) += other.currL1(i) - // merge max and min - currMax(i) = math.max(currMax(i), other.currMax(i)) - currMin(i) = math.min(currMin(i), other.currMin(i)) + if (testMask(weightSumMask)) { + val thisWeightSum = weightSum(i) + val otherWeightSum = other.weightSum(i) + val totalWeightSum = thisWeightSum + otherWeightSum + if (totalWeightSum != 0.0) { + if (testMask(currMeanMask)) { + val deltaMean = other.currMean(i) - currMean(i) + // merge mean together + currMean(i) += deltaMean * otherWeightSum / totalWeightSum + // merge m2n together + if (testMask(currM2nMask)) { + currM2n(i) += other.currM2n(i) + + deltaMean * deltaMean * thisWeightSum * otherWeightSum / totalWeightSum + } + } + // merge m2 together + if(testMask(currM2Mask)) currM2(i) += other.currM2(i) + // merge l1 together + if(testMask(currL1Mask)) currL1(i) += other.currL1(i) + } + weightSum(i) = totalWeightSum + } + if (testMask(nnzMask)) { + val totalNnz = nnz(i) + other.nnz(i) + if (totalNnz != 0) { + // merge max and min + if (testMask(currMaxMask)) currMax(i) = math.max(currMax(i), other.currMax(i)) + if (testMask(currMinMask)) currMin(i) = math.min(currMin(i), other.currMin(i)) + } + nnz(i) = totalNnz } - weightSum(i) = totalNnz - nnz(i) = totalCnnz i += 1 } } else if (totalWeightSum == 0.0 && other.totalWeightSum != 0.0) { this.n = other.n - this.currMean = other.currMean.clone() - this.currM2n = other.currM2n.clone() - this.currM2 = other.currM2.clone() - this.currL1 = other.currL1.clone() + if (testMask(currMeanMask)) this.currMean = other.currMean.clone() + if (testMask(currM2nMask)) this.currM2n = other.currM2n.clone() + if (testMask(currM2Mask)) this.currM2 = other.currM2.clone() + if (testMask(currL1Mask)) this.currL1 = other.currL1.clone() this.totalCnt = other.totalCnt this.totalWeightSum = other.totalWeightSum this.weightSquareSum = other.weightSquareSum - this.weightSum = other.weightSum.clone() - this.nnz = other.nnz.clone() - this.currMax = other.currMax.clone() - this.currMin = other.currMin.clone() + if (testMask(weightSumMask)) this.weightSum = other.weightSum.clone() + if (testMask(nnzMask)) this.nnz = other.nnz.clone() + if (testMask(currMaxMask)) this.currMax = other.currMax.clone() + if (testMask(currMinMask)) this.currMin = other.currMin.clone() } this } @@ -298,3 +321,23 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(currL1) } } + +object MultivariateOnlineSummarizer { + + private val currMeanMask = 0x1 + private val currM2nMask = 0x2 + private val currM2Mask = 0x4 + private val currL1Mask = 0x8 + private val weightSumMask = 0x10 + private val nnzMask = 0x20 + private val currMaxMask = 0x40 + private val currMinMask = 0x80 + + val meanMask = currMeanMask | weightSumMask + val varianceMask = currMeanMask | currM2nMask | weightSumMask + val numNonZerosMask = nnzMask + val maxMask = nnzMask | currMaxMask + val minMask = nnzMask | currMinMask + + val allMask = 0xFFFFFFFF +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 797e84fcc737..ce548c8e98e3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -270,4 +270,34 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) } + + test("mask test") { + import MultivariateOnlineSummarizer._ + + assert(new MultivariateOnlineSummarizer(meanMask) + .add(Vectors.dense(-1.0, 0.0, 6.0)) + .add(Vectors.dense(3.0, -3.0, 0.0)).mean + ~== Vectors.dense(1.0, -1.5, 3.0) absTol 1E-5, "mean mismatch") + + assert(new MultivariateOnlineSummarizer(minMask) + .add(Vectors.dense(-1.0, 0.0, 6.0)) + .add(Vectors.dense(3.0, -3.0, 0.0)).min + ~== Vectors.dense(-1.0, -3, 0.0) absTol 1E-5, "min mismatch") + + assert(new MultivariateOnlineSummarizer(maxMask) + .add(Vectors.dense(-1.0, 0.0, 6.0)) + .add(Vectors.dense(3.0, -3.0, 0.0)).max + ~== Vectors.dense(3.0, 0.0, 6.0) absTol 1E-5, "max mismatch") + + assert(new MultivariateOnlineSummarizer(numNonZerosMask) + .add(Vectors.dense(-1.0, 0.0, 6.0)) + .add(Vectors.dense(3.0, -3.0, 0.0)).numNonzeros + ~== Vectors.dense(2, 1, 1) absTol 1E-5, "numNonzeros mismatch") + + assert(new MultivariateOnlineSummarizer(varianceMask) + .add(Vectors.dense(-1.0, 0.0, 6.0)) + .add(Vectors.dense(3.0, -3.0, 0.0)).variance + ~== Vectors.dense(8.0, 4.5, 18.0) absTol 1E-5, "variance mismatch") + } + }