
Commit cbbefdb

mengxr authored and yinxusen committed
update multivariate statistical summary interface and clean tests
1 parent 4eaf28a commit cbbefdb

File tree

3 files changed: +135, -133 lines

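For context, this change replaces the RowMatrix-local VectorRDDStatisticalSummary trait with a shared org.apache.spark.mllib.stat.MultivariateStatisticalSummary interface and renames the entry point to computeColumnSummaryStatistics(). A minimal usage sketch, assuming a running SparkContext named sc (the data below is made up for illustration, not from the patch):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

// Any RDD[Vector] with a fixed number of columns works as input.
val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 0.0, 3.0),
  Vectors.dense(4.0, 5.0, 0.0),
  Vectors.dense(7.0, 8.0, 9.0)))
val mat = new RowMatrix(rows)

// One pass over the rows produces all column-wise statistics at once.
val summary = mat.computeColumnSummaryStatistics()
println(summary.mean)        // column means
println(summary.variance)    // column sample variances
println(summary.numNonzeros) // per-column count of nonzero entries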

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 66 additions & 90 deletions
@@ -27,75 +27,39 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
+import org.apache.spark.mllib.stat.MultivariateStatisticalSummary
 
 /**
- * Trait of the summary statistics, including mean, variance, count, max, min, and non-zero elements
- * count.
+ * Column statistics aggregator implementing
+ * [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
+ * together with add() and merge() functions.
+ * A numerically stable algorithm is implemented to compute sample mean and variance:
+ * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
+ * Zero elements (including explicit zero values) are skipped when calling add() and merge(),
+ * to have time complexity O(nnz) instead of O(n) for each column.
  */
-trait VectorRDDStatisticalSummary {
+private class ColumnStatisticsAggregator(private val n: Int)
+  extends MultivariateStatisticalSummary with Serializable {
 
-  /**
-   * Computes the mean of columns in RDD[Vector].
-   */
-  def mean: Vector
-
-  /**
-   * Computes the sample variance of columns in RDD[Vector].
-   */
-  def variance: Vector
-
-  /**
-   * Computes number of vectors in RDD[Vector].
-   */
-  def count: Long
-
-  /**
-   * Computes the number of non-zero elements in each column of RDD[Vector].
-   */
-  def numNonZeros: Vector
-
-  /**
-   * Computes the maximum of each column in RDD[Vector].
-   */
-  def max: Vector
-
-  /**
-   * Computes the minimum of each column in RDD[Vector].
-   */
-  def min: Vector
-}
-
-
-/**
- * Aggregates [[org.apache.spark.mllib.linalg.distributed.VectorRDDStatisticalSummary
- * VectorRDDStatisticalSummary]] together with add() and merge() function. Online variance solution
- * used in add() function, while parallel variance solution used in merge() function. Reference here
- * : [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]. Solution
- * here ignoring the zero elements when calling add() and merge(), for decreasing the O(n) algorithm
- * to O(nnz). Real variance is computed here after we get other statistics, simply by another
- * parallel combination process.
- */
-private class VectorRDDStatisticsAggregator(
-    val currMean: BDV[Double],
-    val currM2n: BDV[Double],
-    var totalCnt: Double,
-    val nnz: BDV[Double],
-    val currMax: BDV[Double],
-    val currMin: BDV[Double])
-  extends VectorRDDStatisticalSummary with Serializable {
+  private val currMean: BDV[Double] = BDV.zeros[Double](n)
+  private val currM2n: BDV[Double] = BDV.zeros[Double](n)
+  private var totalCnt: Double = 0.0
+  private val nnz: BDV[Double] = BDV.zeros[Double](n)
+  private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
+  private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)
 
   override def mean = {
-    val realMean = BDV.zeros[Double](currMean.length)
+    val realMean = BDV.zeros[Double](n)
     var i = 0
-    while (i < currMean.length) {
+    while (i < n) {
       realMean(i) = currMean(i) * nnz(i) / totalCnt
       i += 1
     }
     Vectors.fromBreeze(realMean)
   }
 
   override def variance = {
-    val realVariance = BDV.zeros[Double](currM2n.length)
+    val realVariance = BDV.zeros[Double](n)
 
     val denominator = totalCnt - 1.0
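Because add() skips zero entries, currMean holds the mean of the nonzero values only; the mean accessor above rescales it by nnz(i) / totalCnt to recover the true column mean. A quick sanity check of that rescaling on one column with values 0, 0, 3, 5 (plain Scala, not part of the patch):

val values = Seq(0.0, 0.0, 3.0, 5.0)
val nnz = values.count(_ != 0.0).toDouble                  // 2 nonzero entries
val meanOfNonzeros = values.filter(_ != 0.0).sum / nnz     // 4.0, what currMean would hold
val totalCnt = values.size.toDouble                        // 4 rows in total
val mean = meanOfNonzeros * nnz / totalCnt                 // 2.0, the ordinary column mean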

@@ -116,59 +80,60 @@ private class VectorRDDStatisticsAggregator(
 
   override def count: Long = totalCnt.toLong
 
-  override def numNonZeros: Vector = Vectors.fromBreeze(nnz)
+  override def numNonzeros: Vector = Vectors.fromBreeze(nnz)
 
   override def max: Vector = {
     var i = 0
-    while (i < nnz.length) {
-      if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
+    while (i < n) {
+      if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
       i += 1
     }
     Vectors.fromBreeze(currMax)
   }
 
   override def min: Vector = {
     var i = 0
-    while (i < nnz.length) {
+    while (i < n) {
       if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
       i += 1
     }
     Vectors.fromBreeze(currMin)
   }
 
   /**
-   * Aggregate function used for aggregating elements in a worker together.
+   * Aggregates a row.
    */
   def add(currData: BV[Double]): this.type = {
     currData.activeIterator.foreach {
-      // this case is used for filtering the zero elements if the vector.
-      case (id, 0.0) =>
-      case (id, value) =>
-        if (currMax(id) < value) currMax(id) = value
-        if (currMin(id) > value) currMin(id) = value
+      case (_, 0.0) => // Skip explicit zero elements.
+      case (i, value) =>
+        if (currMax(i) < value) currMax(i) = value
+        if (currMin(i) > value) currMin(i) = value
 
-        val tmpPrevMean = currMean(id)
-        currMean(id) = (currMean(id) * nnz(id) + value) / (nnz(id) + 1.0)
-        currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean)
+        val tmpPrevMean = currMean(i)
+        currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
+        currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
 
-        nnz(id) += 1.0
+        nnz(i) += 1.0
     }
 
     totalCnt += 1.0
     this
   }
 
   /**
-   * Combine function used for combining intermediate results together from every worker.
+   * Merges another aggregator.
    */
-  def merge(other: VectorRDDStatisticsAggregator): this.type = {
+  def merge(other: ColumnStatisticsAggregator): this.type = {
+
+    require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")
 
     totalCnt += other.totalCnt
 
     val deltaMean = currMean - other.currMean
 
     var i = 0
-    while (i < other.currMean.length) {
+    while (i < n) {
       // merge mean together
       if (other.currMean(i) != 0.0) {
         currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
@@ -189,6 +154,7 @@ private class VectorRDDStatisticsAggregator(
     }
 
     nnz += other.nnz
+
     this
   }
 }
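The merge() above combines per-partition results with the parallel variance formula from the wiki page cited in the class comment: means are combined as a weighted average, and the sums of squared deviations pick up a cross term. A standalone sketch of that combination step for a single column, with illustrative names that do not appear in the patch:

// (count, mean, m2), where m2 is the sum of squared deviations from that partition's own mean.
case class Partial(count: Double, mean: Double, m2: Double)

def combine(a: Partial, b: Partial): Partial = {
  val count = a.count + b.count
  val delta = b.mean - a.mean
  val mean = a.mean + delta * b.count / count                       // weighted mean
  val m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / count  // cross term added
  Partial(count, mean, m2)
}

// Sample variance of the combined column is then m2 / (count - 1), matching the variance accessor.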
@@ -346,13 +312,7 @@ class RowMatrix(
       combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
     )
 
-    // Update _m if it is not set, or verify its value.
-    if (nRows <= 0L) {
-      nRows = m
-    } else {
-      require(nRows == m,
-        s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
-    }
+    updateNumRows(m)
 
     mean :/= m.toDouble

@@ -405,21 +365,16 @@ class RowMatrix(
   }
 
   /**
-   * Compute full column-wise statistics for the RDD with the size of Vector as input parameter.
+   * Computes column-wise summary statistics.
    */
-  def multiVariateSummaryStatistics(): VectorRDDStatisticalSummary = {
-    val zeroValue = new VectorRDDStatisticsAggregator(
-      BDV.zeros[Double](nCols),
-      BDV.zeros[Double](nCols),
-      0.0,
-      BDV.zeros[Double](nCols),
-      BDV.fill(nCols)(Double.MinValue),
-      BDV.fill(nCols)(Double.MaxValue))
-
-    rows.map(_.toBreeze).aggregate[VectorRDDStatisticsAggregator](zeroValue)(
+  def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
+    val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
+    val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
     )
+    updateNumRows(summary.count)
+    summary
   }
 
   /**
@@ -458,6 +413,27 @@ class RowMatrix(
     }
     mat
   }
+
+  /** Updates or verifies the number of columns. */
+  private def updateNumCols(n: Int) {
+    if (nCols <= 0) {
+      nCols = n
+    } else {
+      require(nCols == n,
+        s"The number of columns $n is different from " +
+          s"what specified or previously computed: ${nCols}.")
+    }
+  }
+
+  /** Updates or verifies the number of rows. */
+  private def updateNumRows(m: Long) {
+    if (nRows <= 0) {
+      nRows = m
+    } else {
+      require(nRows == m,
+        s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
+    }
+  }
 }
 
 object RowMatrix {
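computeColumnSummaryStatistics() wires the aggregator into RDD.aggregate: the first function (seqOp) folds each row into a per-partition aggregator via add(), and the second (combOp) merges partition results via merge(). A self-contained sketch of the same pattern with a toy accumulator, assuming a SparkContext named sc (everything here is illustrative, not from the patch):

import org.apache.spark.SparkContext

// Toy accumulator with the same add()/merge() split as ColumnStatisticsAggregator.
class SumCount(var sum: Double = 0.0, var count: Long = 0L) extends Serializable {
  def add(x: Double): this.type = { sum += x; count += 1; this }
  def merge(other: SumCount): this.type = { sum += other.sum; count += other.count; this }
}

def meanOf(sc: SparkContext, xs: Seq[Double]): Double = {
  val agg = sc.parallelize(xs).aggregate(new SumCount())(
    (acc, x) => acc.add(x),  // seqOp: runs inside each partition
    (a, b) => a.merge(b))    // combOp: merges per-partition results
  agg.sum / agg.count
}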
mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat
+
+import org.apache.spark.mllib.linalg.Vector
+
+/**
+ * Trait for multivariate statistical summary of a data matrix.
+ */
+trait MultivariateStatisticalSummary {
+
+  /**
+   * Sample mean vector.
+   */
+  def mean: Vector
+
+  /**
+   * Sample variance vector. Should return a zero vector if the sample size is 1.
+   */
+  def variance: Vector
+
+  /**
+   * Sample size.
+   */
+  def count: Long
+
+  /**
+   * Number of nonzero elements (including explicitly presented zero values) in each column.
+   */
+  def numNonzeros: Vector
+
+  /**
+   * Maximum value of each column.
+   */
+  def max: Vector
+
+  /**
+   * Minimum value of each column.
+   */
+  def min: Vector
+}
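The trait above only specifies accessors; for intuition, a minimal local (non-distributed) implementation over an in-memory Seq[Vector] could look like the sketch below. It is purely illustrative, not part of the patch, and assumes a non-empty data set.

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary

class LocalSummary(data: Seq[Vector]) extends MultivariateStatisticalSummary {
  private val rows = data.map(_.toArray)   // assumes data is non-empty
  private val n = rows.head.length
  private def byColumn(f: Seq[Double] => Double): Vector =
    Vectors.dense(Array.tabulate(n)(j => f(rows.map(_(j)))))

  override def count: Long = rows.size.toLong
  override def mean: Vector = byColumn(col => col.sum / col.size)
  override def variance: Vector = byColumn { col =>
    val m = col.sum / col.size
    // Sample variance; zero when there is only one row, as the trait requires.
    if (col.size > 1) col.map(x => (x - m) * (x - m)).sum / (col.size - 1) else 0.0
  }
  override def numNonzeros: Vector = byColumn(_.count(_ != 0.0).toDouble)
  override def max: Vector = byColumn(_.max)
  override def min: Vector = byColumn(_.min)
}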

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala

Lines changed: 13 additions & 43 deletions
@@ -137,9 +137,6 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
     brzNorm(v, 1.0) < 1e-6
   }
 
-  def equivVector(lhs: Vector, rhs: Vector): Boolean =
-    closeToZero(lhs.toBreeze.asInstanceOf[BDV[Double]] - rhs.toBreeze.asInstanceOf[BDV[Double]])
-
   def assertColumnEqualUpToSign(A: BDM[Double], B: BDM[Double], k: Int) {
     assert(A.rows === B.rows)
     for (j <- 0 until k) {
@@ -174,45 +171,18 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
     }
   }
 
-  test("dense statistical summary") {
-    val summary = denseMat.multiVariateSummaryStatistics()
-
-    assert(equivVector(summary.mean, Vectors.dense(4.5, 3.0, 4.0)),
-      "Dense column mean do not match.")
-
-    assert(equivVector(summary.variance, Vectors.dense(15.0, 10.0, 10.0)),
-      "Dense column variance do not match.")
-
-    assert(summary.count === 4, "Dense column cnt do not match.")
-
-    assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 4.0)),
-      "Dense column nnz do not match.")
-
-    assert(equivVector(summary.max, Vectors.dense(9.0, 7.0, 8.0)),
-      "Dense column max do not match.")
-
-    assert(equivVector(summary.min, Vectors.dense(0.0, 0.0, 1.0)),
-      "Dense column min do not match.")
-  }
-
-  test("sparse statistical summary") {
-    val summary = sparseMat.multiVariateSummaryStatistics()
-
-    assert(equivVector(summary.mean, Vectors.dense(4.5, 3.0, 4.0)),
-      "Sparse column mean do not match.")
-
-    assert(equivVector(summary.variance, Vectors.dense(15.0, 10.0, 10.0)),
-      "Sparse column variance do not match.")
-
-    assert(summary.count === 4, "Sparse column cnt do not match.")
-
-    assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 4.0)),
-      "Sparse column nnz do not match.")
-
-    assert(equivVector(summary.max, Vectors.dense(9.0, 7.0, 8.0)),
-      "Sparse column max do not match.")
-
-    assert(equivVector(summary.min, Vectors.dense(0.0, 0.0, 1.0)),
-      "Sparse column min do not match.")
+  test("compute column summary statistics") {
+    for (mat <- Seq(denseMat, sparseMat)) {
+      val summary = mat.computeColumnSummaryStatistics()
+      // Run twice to make sure no internal states are changed.
+      for (k <- 0 to 1) {
+        assert(summary.mean === Vectors.dense(4.5, 3.0, 4.0), "mean mismatch")
+        assert(summary.variance === Vectors.dense(15.0, 10.0, 10.0), "variance mismatch")
+        assert(summary.count === m, "count mismatch")
+        assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch")
+        assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch")
+        assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "min mismatch")
+      }
+    }
   }
 }
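The expected constants in the new test are consistent with a 4x3 fixture whose rows are (0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 0, 1). The fixture itself (denseMat, sparseMat, and the row count m) is defined elsewhere in RowMatrixSuite and is not shown in this diff, so the following recomputation is an illustration of where the numbers come from rather than a quote of the test data:

val rows = Seq(Seq(0.0, 1.0, 2.0), Seq(3.0, 4.0, 5.0), Seq(6.0, 7.0, 8.0), Seq(9.0, 0.0, 1.0))
val cols = rows.transpose
val mean = cols.map(c => c.sum / c.size)             // List(4.5, 3.0, 4.0)
val variance = cols.map { c =>
  val m = c.sum / c.size
  c.map(x => (x - m) * (x - m)).sum / (c.size - 1)   // sample variance
}                                                    // List(15.0, 10.0, 10.0)
val nnz = cols.map(_.count(_ != 0.0))                // List(3, 3, 4)
val max = cols.map(_.max)                            // List(9.0, 7.0, 8.0)
val min = cols.map(_.min)                            // List(0.0, 0.0, 1.0)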
