Skip to content

Commit b13ac90

Browse files
dbtsai (DB Tsai)
authored and committed
dbtsai-summarizer
1 parent f4f46de commit b13ac90

File tree

5 files changed

+458
-134
lines changed

5 files changed

+458
-134
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 2 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -28,138 +28,7 @@ import org.apache.spark.annotation.Experimental
2828
import org.apache.spark.mllib.linalg._
2929
import org.apache.spark.rdd.RDD
3030
import org.apache.spark.Logging
31-
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary
32-
33-
/**
34-
* Column statistics aggregator implementing
35-
* [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
36-
* together with add() and merge() function.
37-
* A numerically stable algorithm is implemented to compute sample mean and variance:
38-
* [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
39-
* Zero elements (including explicit zero values) are skipped when calling add() and merge(),
40-
* to have time complexity O(nnz) instead of O(n) for each column.
41-
*/
42-
private class ColumnStatisticsAggregator(private val n: Int)
43-
extends MultivariateStatisticalSummary with Serializable {
44-
45-
private val currMean: BDV[Double] = BDV.zeros[Double](n)
46-
private val currM2n: BDV[Double] = BDV.zeros[Double](n)
47-
private var totalCnt = 0.0
48-
private val nnz: BDV[Double] = BDV.zeros[Double](n)
49-
private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
50-
private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)
51-
52-
override def mean: Vector = {
53-
val realMean = BDV.zeros[Double](n)
54-
var i = 0
55-
while (i < n) {
56-
realMean(i) = currMean(i) * nnz(i) / totalCnt
57-
i += 1
58-
}
59-
Vectors.fromBreeze(realMean)
60-
}
61-
62-
override def variance: Vector = {
63-
val realVariance = BDV.zeros[Double](n)
64-
65-
val denominator = totalCnt - 1.0
66-
67-
// Sample variance is computed, if the denominator is less than 0, the variance is just 0.
68-
if (denominator > 0.0) {
69-
val deltaMean = currMean
70-
var i = 0
71-
while (i < currM2n.size) {
72-
realVariance(i) =
73-
currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
74-
realVariance(i) /= denominator
75-
i += 1
76-
}
77-
}
78-
79-
Vectors.fromBreeze(realVariance)
80-
}
81-
82-
override def count: Long = totalCnt.toLong
83-
84-
override def numNonzeros: Vector = Vectors.fromBreeze(nnz)
85-
86-
override def max: Vector = {
87-
var i = 0
88-
while (i < n) {
89-
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
90-
i += 1
91-
}
92-
Vectors.fromBreeze(currMax)
93-
}
94-
95-
override def min: Vector = {
96-
var i = 0
97-
while (i < n) {
98-
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
99-
i += 1
100-
}
101-
Vectors.fromBreeze(currMin)
102-
}
103-
104-
/**
105-
* Aggregates a row.
106-
*/
107-
def add(currData: BV[Double]): this.type = {
108-
currData.activeIterator.foreach {
109-
case (_, 0.0) => // Skip explicit zero elements.
110-
case (i, value) =>
111-
if (currMax(i) < value) {
112-
currMax(i) = value
113-
}
114-
if (currMin(i) > value) {
115-
currMin(i) = value
116-
}
117-
118-
val tmpPrevMean = currMean(i)
119-
currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
120-
currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
121-
122-
nnz(i) += 1.0
123-
}
124-
125-
totalCnt += 1.0
126-
this
127-
}
128-
129-
/**
130-
* Merges another aggregator.
131-
*/
132-
def merge(other: ColumnStatisticsAggregator): this.type = {
133-
require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")
134-
135-
totalCnt += other.totalCnt
136-
val deltaMean = currMean - other.currMean
137-
138-
var i = 0
139-
while (i < n) {
140-
// merge mean together
141-
if (other.currMean(i) != 0.0) {
142-
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
143-
(nnz(i) + other.nnz(i))
144-
}
145-
// merge m2n together
146-
if (nnz(i) + other.nnz(i) != 0.0) {
147-
currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
148-
(nnz(i) + other.nnz(i))
149-
}
150-
if (currMax(i) < other.currMax(i)) {
151-
currMax(i) = other.currMax(i)
152-
}
153-
if (currMin(i) > other.currMin(i)) {
154-
currMin(i) = other.currMin(i)
155-
}
156-
i += 1
157-
}
158-
159-
nnz += other.nnz
160-
this
161-
}
162-
}
31+
import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
16332

16433
/**
16534
* :: Experimental ::
@@ -478,8 +347,7 @@ class RowMatrix(
478347
* Computes column-wise summary statistics.
479348
*/
480349
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
481-
val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
482-
val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
350+
val summary = rows.aggregate[MultivariateOnlineSummarizer](new MultivariateOnlineSummarizer)(
483351
(aggregator, data) => aggregator.add(data),
484352
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
485353
)
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.stat
19+
20+
import breeze.linalg.{DenseVector => BDV}
21+
22+
import org.apache.spark.annotation.DeveloperApi
23+
import org.apache.spark.mllib.linalg.{Vectors, Vector}
24+
25+
/**
 * :: DeveloperApi ::
 * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean,
 * variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector
 * format in an online fashion.
 *
 * Two MultivariateOnlineSummarizer instances can be merged together to obtain a statistical
 * summary of the corresponding joint dataset.
 *
 * A numerically stable algorithm is implemented to compute sample mean and variance:
 * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]
 * Zero elements (including explicit zero values) are skipped when calling add(),
 * to have time complexity O(nnz) instead of O(n) for each column.
 *
 * Internally, `currMean` and `currM2n` track the mean and the sum of squared deviations of the
 * '''nonzero''' entries of each column only; the accessors fold the implicit zeros back in
 * using `nnz` and `totalCnt`.
 */
@DeveloperApi
class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {

  // Number of columns; 0 until the first sample is added or a non-empty summarizer is merged in.
  private var n = 0
  // Per-column mean of the nonzero entries seen so far.
  private var currMean: BDV[Double] = _
  // Per-column sum of squared deviations of the nonzero entries from currMean.
  private var currM2n: BDV[Double] = _
  // Total number of samples (rows) seen so far.
  private var totalCnt: Long = 0
  // Per-column count of nonzero entries.
  private var nnz: BDV[Double] = _
  // Per-column max / min over the nonzero entries only (zeros are accounted for in max / min).
  private var currMax: BDV[Double] = _
  private var currMin: BDV[Double] = _

  /**
   * Add a new sample to this summarizer, and update the statistical summary.
   *
   * @param sample The sample in dense/sparse vector format to be added into this summarizer.
   * @return This MultivariateOnlineSummarizer object.
   * @throws IllegalArgumentException if the sample is empty, or if its dimension differs from
   *                                  that of the samples added previously.
   */
  def add(sample: Vector): this.type = {
    // Convert once and reuse, instead of calling toBreeze for every length check and iteration.
    val data = sample.toBreeze
    val dim = data.length

    if (n == 0) {
      require(dim > 0, s"Vector should have dimension larger than zero.")
      n = dim

      currMean = BDV.zeros[Double](n)
      currM2n = BDV.zeros[Double](n)
      nnz = BDV.zeros[Double](n)
      currMax = BDV.fill(n)(Double.MinValue)
      currMin = BDV.fill(n)(Double.MaxValue)
    }

    require(n == dim, s"Dimensions mismatch when adding new sample." +
      s" Expecting $n but got $dim.")

    data.activeIterator.foreach {
      case (_, 0.0) => // Skip explicit zero elements.
      case (i, value) =>
        if (currMax(i) < value) {
          currMax(i) = value
        }
        if (currMin(i) > value) {
          currMin(i) = value
        }

        // Online update of the mean and M2 over the nonzero entries of column i.
        val tmpPrevMean = currMean(i)
        currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
        currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)

        nnz(i) += 1.0
    }

    totalCnt += 1
    this
  }

  /**
   * Merge another MultivariateOnlineSummarizer, and update the statistical summary.
   * (Note that it's in place merging; as a result, `this` object will be modified.)
   *
   * If `other` is empty this is a no-op; if `this` is empty it becomes a copy of `other`.
   *
   * @param other The other MultivariateOnlineSummarizer to be merged.
   * @return This MultivariateOnlineSummarizer object.
   * @throws IllegalArgumentException if both summarizers are non-empty and their dimensions
   *                                  differ.
   */
  def merge(other: MultivariateOnlineSummarizer): this.type = {
    if (this.totalCnt != 0 && other.totalCnt != 0) {
      require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
        s"Expecting $n but got ${other.n}.")
      totalCnt += other.totalCnt
      var i = 0
      while (i < n) {
        val thisNnz = nnz(i)
        val otherNnz = other.nnz(i)
        val totalNnz = thisNnz + otherNnz
        // Guard on the combined nonzero count, not on other's mean: a column whose nonzero
        // entries in `other` average to exactly 0.0 must still re-weight the merged mean.
        if (totalNnz != 0.0) {
          // Difference of the two pre-merge means; must be taken before currMean is updated.
          val deltaMean = currMean(i) - other.currMean(i)
          // merge m2n together
          currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
          // merge mean together
          currMean(i) = (currMean(i) * thisNnz + other.currMean(i) * otherNnz) / totalNnz
        }
        if (currMax(i) < other.currMax(i)) {
          currMax(i) = other.currMax(i)
        }
        if (currMin(i) > other.currMin(i)) {
          currMin(i) = other.currMin(i)
        }
        nnz(i) = totalNnz
        i += 1
      }
    } else if (totalCnt == 0 && other.totalCnt != 0) {
      // Nothing accumulated here yet: adopt a private copy of other's state.
      this.n = other.n
      this.currMean = other.currMean.copy
      this.currM2n = other.currM2n.copy
      this.totalCnt = other.totalCnt
      this.nnz = other.nnz.copy
      this.currMax = other.currMax.copy
      this.currMin = other.currMin.copy
    }
    this
  }

  /**
   * Per-column sample mean over all samples, zeros included.
   *
   * @throws IllegalArgumentException if no sample has been added.
   */
  override def mean: Vector = {
    require(totalCnt > 0, s"Nothing has been added to this summarizer.")

    val realMean = BDV.zeros[Double](n)
    var i = 0
    while (i < n) {
      // Scale the mean of the nonzero entries by the fraction of rows that are nonzero.
      realMean(i) = currMean(i) * (nnz(i) / totalCnt)
      i += 1
    }
    Vectors.fromBreeze(realMean)
  }

  /**
   * Per-column sample variance (denominator `count - 1`) over all samples, zeros included.
   *
   * @throws IllegalArgumentException if no sample has been added.
   */
  override def variance: Vector = {
    require(totalCnt > 0, s"Nothing has been added to this summarizer.")

    val realVariance = BDV.zeros[Double](n)

    val denominator = totalCnt - 1.0

    // Sample variance is computed; if the denominator is not positive (a single sample),
    // the variance is left at 0.
    if (denominator > 0.0) {
      var i = 0
      while (i < n) {
        val nonzeroMean = currMean(i)
        // M2 of the nonzero entries plus the correction for the (totalCnt - nnz) implicit zeros.
        realVariance(i) =
          currM2n(i) + nonzeroMean * nonzeroMean * nnz(i) * (totalCnt - nnz(i)) / totalCnt
        realVariance(i) /= denominator
        i += 1
      }
    }

    Vectors.fromBreeze(realVariance)
  }

  /** Total number of samples added (directly or through merging). */
  override def count: Long = totalCnt

  /**
   * Per-column number of nonzero entries.
   *
   * @throws IllegalArgumentException if no sample has been added.
   */
  override def numNonzeros: Vector = {
    require(totalCnt > 0, s"Nothing has been added to this summarizer.")

    // Hand out a copy so the result is not backed by this summarizer's mutable buffer.
    Vectors.fromBreeze(nnz.copy)
  }

  /**
   * Per-column maximum over all samples, zeros included.
   *
   * @throws IllegalArgumentException if no sample has been added.
   */
  override def max: Vector = {
    require(totalCnt > 0, s"Nothing has been added to this summarizer.")

    // Work on a copy: an accessor should not rewrite the accumulated state.
    val realMax = currMax.copy
    var i = 0
    while (i < n) {
      // At least one skipped zero exists in this column, so the max cannot be negative.
      if ((nnz(i) < totalCnt) && (realMax(i) < 0.0)) realMax(i) = 0.0
      i += 1
    }
    Vectors.fromBreeze(realMax)
  }

  /**
   * Per-column minimum over all samples, zeros included.
   *
   * @throws IllegalArgumentException if no sample has been added.
   */
  override def min: Vector = {
    require(totalCnt > 0, s"Nothing has been added to this summarizer.")

    // Work on a copy: an accessor should not rewrite the accumulated state.
    val realMin = currMin.copy
    var i = 0
    while (i < n) {
      // At least one skipped zero exists in this column, so the min cannot be positive.
      if ((nnz(i) < totalCnt) && (realMin(i) > 0.0)) realMin(i) = 0.0
      i += 1
    }
    Vectors.fromBreeze(realMin)
  }
}

0 commit comments

Comments
 (0)