@@ -27,75 +27,39 @@ import org.apache.spark.annotation.Experimental
2727import org .apache .spark .mllib .linalg ._
2828import org .apache .spark .rdd .RDD
2929import org .apache .spark .Logging
30+ import org .apache .spark .mllib .stat .MultivariateStatisticalSummary
3031
3132/**
32- * Trait of the summary statistics, including mean, variance, count, max, min, and non-zero elements
33- * count.
33+ * Column statistics aggregator implementing
34+ * [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary ]]
35+ * together with add() and merge() function.
36+ * A numerically stable algorithm is implemented to compute sample mean and variance:
37+ *[[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki ]].
38+ * Zero elements (including explicit zero values) are skipped when calling add() and merge(),
39+ * to have time complexity O(nnz) instead of O(n) for each column.
3440 */
35- trait VectorRDDStatisticalSummary {
41+ private class ColumnStatisticsAggregator (private val n : Int )
42+ extends MultivariateStatisticalSummary with Serializable {
3643
37- /**
38- * Computes the mean of columns in RDD[Vector].
39- */
40- def mean : Vector
41-
42- /**
43- * Computes the sample variance of columns in RDD[Vector].
44- */
45- def variance : Vector
46-
47- /**
48- * Computes number of vectors in RDD[Vector].
49- */
50- def count : Long
51-
52- /**
53- * Computes the number of non-zero elements in each column of RDD[Vector].
54- */
55- def numNonZeros : Vector
56-
57- /**
58- * Computes the maximum of each column in RDD[Vector].
59- */
60- def max : Vector
61-
62- /**
63- * Computes the minimum of each column in RDD[Vector].
64- */
65- def min : Vector
66- }
67-
68-
69- /**
70- * Aggregates [[org.apache.spark.mllib.linalg.distributed.VectorRDDStatisticalSummary
71- * VectorRDDStatisticalSummary]] together with add() and merge() function. Online variance solution
72- * used in add() function, while parallel variance solution used in merge() function. Reference here
73- * : [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki ]]. Solution
74- * here ignoring the zero elements when calling add() and merge(), for decreasing the O(n) algorithm
75- * to O(nnz). Real variance is computed here after we get other statistics, simply by another
76- * parallel combination process.
77- */
78- private class VectorRDDStatisticsAggregator (
79- val currMean : BDV [Double ],
80- val currM2n : BDV [Double ],
81- var totalCnt : Double ,
82- val nnz : BDV [Double ],
83- val currMax : BDV [Double ],
84- val currMin : BDV [Double ])
85- extends VectorRDDStatisticalSummary with Serializable {
44+ private val currMean : BDV [Double ] = BDV .zeros[Double ](n)
45+ private val currM2n : BDV [Double ] = BDV .zeros[Double ](n)
46+ private var totalCnt : Double = 0.0
47+ private val nnz : BDV [Double ] = BDV .zeros[Double ](n)
48+ private val currMax : BDV [Double ] = BDV .fill(n)(Double .MinValue )
49+ private val currMin : BDV [Double ] = BDV .fill(n)(Double .MaxValue )
8650
8751 override def mean = {
88- val realMean = BDV .zeros[Double ](currMean.length )
52+ val realMean = BDV .zeros[Double ](n )
8953 var i = 0
90- while (i < currMean.length ) {
54+ while (i < n ) {
9155 realMean(i) = currMean(i) * nnz(i) / totalCnt
9256 i += 1
9357 }
9458 Vectors .fromBreeze(realMean)
9559 }
9660
9761 override def variance = {
98- val realVariance = BDV .zeros[Double ](currM2n.length )
62+ val realVariance = BDV .zeros[Double ](n )
9963
10064 val denominator = totalCnt - 1.0
10165
@@ -116,59 +80,60 @@ private class VectorRDDStatisticsAggregator(
11680
11781 override def count : Long = totalCnt.toLong
11882
119- override def numNonZeros : Vector = Vectors .fromBreeze(nnz)
83+ override def numNonzeros : Vector = Vectors .fromBreeze(nnz)
12084
12185 override def max : Vector = {
12286 var i = 0
123- while (i < nnz.length ) {
124- if ((nnz(i) < totalCnt) && (currMax(i) < 0.0 )) currMax(i) = 0.0
87+ while (i < n ) {
88+ if ((nnz(i) < totalCnt) && (currMax(i) < 0.0 )) currMax(i) = 0.0
12589 i += 1
12690 }
12791 Vectors .fromBreeze(currMax)
12892 }
12993
13094 override def min : Vector = {
13195 var i = 0
132- while (i < nnz.length ) {
96+ while (i < n ) {
13397 if ((nnz(i) < totalCnt) && (currMin(i) > 0.0 )) currMin(i) = 0.0
13498 i += 1
13599 }
136100 Vectors .fromBreeze(currMin)
137101 }
138102
139103 /**
140- * Aggregate function used for aggregating elements in a worker together .
104+ * Aggregates a row .
141105 */
142106 def add (currData : BV [Double ]): this .type = {
143107 currData.activeIterator.foreach {
144- // this case is used for filtering the zero elements if the vector.
145- case (id, 0.0 ) =>
146- case (id, value) =>
147- if (currMax(id) < value) currMax(id) = value
148- if (currMin(id) > value) currMin(id) = value
108+ case (_, 0.0 ) => // Skip explicit zero elements.
109+ case (i, value) =>
110+ if (currMax(i) < value) currMax(i) = value
111+ if (currMin(i) > value) currMin(i) = value
149112
150- val tmpPrevMean = currMean(id )
151- currMean(id ) = (currMean(id ) * nnz(id ) + value) / (nnz(id ) + 1.0 )
152- currM2n(id ) += (value - currMean(id )) * (value - tmpPrevMean)
113+ val tmpPrevMean = currMean(i )
114+ currMean(i ) = (currMean(i ) * nnz(i ) + value) / (nnz(i ) + 1.0 )
115+ currM2n(i ) += (value - currMean(i )) * (value - tmpPrevMean)
153116
154- nnz(id ) += 1.0
117+ nnz(i ) += 1.0
155118 }
156119
157120 totalCnt += 1.0
158121 this
159122 }
160123
161124 /**
162- * Combine function used for combining intermediate results together from every worker .
125+ * Merges another aggregator .
163126 */
164- def merge (other : VectorRDDStatisticsAggregator ): this .type = {
127+ def merge (other : ColumnStatisticsAggregator ): this .type = {
128+
129+ require(n == other.n, s " Dimensions mismatch. Expecting $n but got ${other.n}. " )
165130
166131 totalCnt += other.totalCnt
167132
168133 val deltaMean = currMean - other.currMean
169134
170135 var i = 0
171- while (i < other.currMean.length ) {
136+ while (i < n ) {
172137 // merge mean together
173138 if (other.currMean(i) != 0.0 ) {
174139 currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
@@ -189,6 +154,7 @@ private class VectorRDDStatisticsAggregator(
189154 }
190155
191156 nnz += other.nnz
157+
192158 this
193159 }
194160}
@@ -346,13 +312,7 @@ class RowMatrix(
346312 combOp = (s1 : (Long , BDV [Double ]), s2 : (Long , BDV [Double ])) => (s1._1 + s2._1, s1._2 += s2._2)
347313 )
348314
349- // Update _m if it is not set, or verify its value.
350- if (nRows <= 0L ) {
351- nRows = m
352- } else {
353- require(nRows == m,
354- s " The number of rows $m is different from what specified or previously computed: ${nRows}. " )
355- }
315+ updateNumRows(m)
356316
357317 mean :/= m.toDouble
358318
@@ -405,21 +365,16 @@ class RowMatrix(
405365 }
406366
407367 /**
408- * Compute full column-wise statistics for the RDD with the size of Vector as input parameter .
368+ * Computes column-wise summary statistics .
409369 */
410- def multiVariateSummaryStatistics (): VectorRDDStatisticalSummary = {
411- val zeroValue = new VectorRDDStatisticsAggregator (
412- BDV .zeros[Double ](nCols),
413- BDV .zeros[Double ](nCols),
414- 0.0 ,
415- BDV .zeros[Double ](nCols),
416- BDV .fill(nCols)(Double .MinValue ),
417- BDV .fill(nCols)(Double .MaxValue ))
418-
419- rows.map(_.toBreeze).aggregate[VectorRDDStatisticsAggregator ](zeroValue)(
370+ def computeColumnSummaryStatistics (): MultivariateStatisticalSummary = {
371+ val zeroValue = new ColumnStatisticsAggregator (numCols().toInt)
372+ val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator ](zeroValue)(
420373 (aggregator, data) => aggregator.add(data),
421374 (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
422375 )
376+ updateNumRows(summary.count)
377+ summary
423378 }
424379
425380 /**
@@ -458,6 +413,27 @@ class RowMatrix(
458413 }
459414 mat
460415 }
416+
417+ /** Updates or verifies the number of columns. */
418+ private def updateNumCols (n : Int ) {
419+ if (nCols <= 0 ) {
420+ nCols == n
421+ } else {
422+ require(nCols == n,
423+ s " The number of columns $n is different from " +
424+ s " what specified or previously computed: ${nCols}. " )
425+ }
426+ }
427+
428+ /** Updates or verfires the number of rows. */
429+ private def updateNumRows (m : Long ) {
430+ if (nRows <= 0 ) {
431+ nRows == m
432+ } else {
433+ require(nRows == m,
434+ s " The number of rows $m is different from what specified or previously computed: ${nRows}. " )
435+ }
436+ }
461437}
462438
463439object RowMatrix {
0 commit comments