diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 5ab5b4fc06ec..9a8d7c2e78f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -26,67 +26,31 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD import org.apache.spark.Logging +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary /** - * Trait of the summary statistics, including mean, variance, count, max, min, and non-zero elements - * count. + * Column statistics aggregator implementing + * [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]] + * together with add() and merge() function. + * A numerically stable algorithm is implemented to compute sample mean and variance: + *[[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]. + * Zero elements (including explicit zero values) are skipped when calling add() and merge(), + * to have time complexity O(nnz) instead of O(n) for each column. */ -trait VectorRDDStatisticalSummary { +private class ColumnStatisticsAggregator(private val n: Int) + extends MultivariateStatisticalSummary with Serializable { - /** - * Computes the mean of columns in RDD[Vector]. - */ - def mean: Vector - - /** - * Computes the sample variance of columns in RDD[Vector]. - */ - def variance: Vector - - /** - * Computes number of vectors in RDD[Vector]. - */ - def count: Long - - /** - * Computes the number of non-zero elements in each column of RDD[Vector]. - */ - def numNonZeros: Vector - - /** - * Computes the maximum of each column in RDD[Vector]. - */ - def max: Vector - - /** - * Computes the minimum of each column in RDD[Vector]. - */ - def min: Vector -} - - -/** - * Aggregates [[org.apache.spark.mllib.linalg.distributed.VectorRDDStatisticalSummary - * VectorRDDStatisticalSummary]] together with add() and merge() function. Online variance solution - * used in add() function, while parallel variance solution used in merge() function. Reference here - * : [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]. Solution - * here ignoring the zero elements when calling add() and merge(), for decreasing the O(n) algorithm - * to O(nnz). Real variance is computed here after we get other statistics, simply by another - * parallel combination process. - */ -private class VectorRDDStatisticsAggregator( - val currMean: BDV[Double], - val currM2n: BDV[Double], - var totalCnt: Double, - val nnz: BDV[Double], - val currMax: BDV[Double], - val currMin: BDV[Double]) - extends VectorRDDStatisticalSummary with Serializable { + private val currMean: BDV[Double] = BDV.zeros[Double](n) + private val currM2n: BDV[Double] = BDV.zeros[Double](n) + private var totalCnt: Double = 0.0 + private val nnz: BDV[Double] = BDV.zeros[Double](n) + private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue) + private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue) override def mean = { - val realMean = BDV.zeros[Double](currMean.length) + val realMean = BDV.zeros[Double](n) var i = 0 - while (i < currMean.length) { + while (i < n) { realMean(i) = currMean(i) * nnz(i) / totalCnt i += 1 } @@ -94,7 +58,7 @@ private class VectorRDDStatisticsAggregator( } override def variance = { - val realVariance = BDV.zeros[Double](currM2n.length) + val realVariance = BDV.zeros[Double](n) val denominator = totalCnt - 1.0 @@ -115,12 +79,12 @@ private class VectorRDDStatisticsAggregator( override def count: Long = totalCnt.toLong - override def numNonZeros: Vector = Vectors.fromBreeze(nnz) + override def numNonzeros: Vector = Vectors.fromBreeze(nnz) override def max: Vector = { var i = 0 - while (i < nnz.length) { - if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 + while (i < n) { + if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } Vectors.fromBreeze(currMax) @@ -128,7 +92,7 @@ private class VectorRDDStatisticsAggregator( override def min: Vector = { var i = 0 - while (i < nnz.length) { + while (i < n) { if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } @@ -136,21 +100,20 @@ private class VectorRDDStatisticsAggregator( } /** - * Aggregate function used for aggregating elements in a worker together. + * Aggregates a row. */ def add(currData: BV[Double]): this.type = { currData.activeIterator.foreach { - // this case is used for filtering the zero elements if the vector. - case (id, 0.0) => - case (id, value) => - if (currMax(id) < value) currMax(id) = value - if (currMin(id) > value) currMin(id) = value + case (_, 0.0) => // Skip explicit zero elements. + case (i, value) => + if (currMax(i) < value) currMax(i) = value + if (currMin(i) > value) currMin(i) = value - val tmpPrevMean = currMean(id) - currMean(id) = (currMean(id) * nnz(id) + value) / (nnz(id) + 1.0) - currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean) + val tmpPrevMean = currMean(i) + currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) + currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) - nnz(id) += 1.0 + nnz(i) += 1.0 } totalCnt += 1.0 @@ -158,16 +121,18 @@ private class VectorRDDStatisticsAggregator( } /** - * Combine function used for combining intermediate results together from every worker. + * Merges another aggregator. */ - def merge(other: VectorRDDStatisticsAggregator): this.type = { + def merge(other: ColumnStatisticsAggregator): this.type = { + + require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.") totalCnt += other.totalCnt val deltaMean = currMean - other.currMean var i = 0 - while (i < other.currMean.length) { + while (i < n) { // merge mean together if (other.currMean(i) != 0.0) { currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) / @@ -188,6 +153,7 @@ private class VectorRDDStatisticsAggregator( } nnz += other.nnz + this } } @@ -344,13 +310,7 @@ class RowMatrix( combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2) ) - // Update _m if it is not set, or verify its value. - if (nRows <= 0L) { - nRows = m - } else { - require(nRows == m, - s"The number of rows $m is different from what specified or previously computed: ${nRows}.") - } + updateNumRows(m) mean :/= m.toDouble @@ -403,21 +363,16 @@ class RowMatrix( } /** - * Compute full column-wise statistics for the RDD with the size of Vector as input parameter. + * Computes column-wise summary statistics. */ - def multiVariateSummaryStatistics(): VectorRDDStatisticalSummary = { - val zeroValue = new VectorRDDStatisticsAggregator( - BDV.zeros[Double](nCols), - BDV.zeros[Double](nCols), - 0.0, - BDV.zeros[Double](nCols), - BDV.fill(nCols)(Double.MinValue), - BDV.fill(nCols)(Double.MaxValue)) - - rows.map(_.toBreeze).aggregate[VectorRDDStatisticsAggregator](zeroValue)( + def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = { + val zeroValue = new ColumnStatisticsAggregator(numCols().toInt) + val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)( (aggregator, data) => aggregator.add(data), (aggregator1, aggregator2) => aggregator1.merge(aggregator2) ) + updateNumRows(summary.count) + summary } /** @@ -456,6 +411,27 @@ class RowMatrix( } mat } + + /** Updates or verifies the number of columns. */ + private def updateNumCols(n: Int) { + if (nCols <= 0) { + nCols == n + } else { + require(nCols == n, + s"The number of columns $n is different from " + + s"what specified or previously computed: ${nCols}.") + } + } + + /** Updates or verfires the number of rows. */ + private def updateNumRows(m: Long) { + if (nRows <= 0) { + nRows == m + } else { + require(nRows == m, + s"The number of rows $m is different from what specified or previously computed: ${nRows}.") + } + } } object RowMatrix { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala new file mode 100644 index 000000000000..f9eb343da2b8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.mllib.linalg.Vector + +/** + * Trait for multivariate statistical summary of a data matrix. + */ +trait MultivariateStatisticalSummary { + + /** + * Sample mean vector. + */ + def mean: Vector + + /** + * Sample variance vector. Should return a zero vector if the sample size is 1. + */ + def variance: Vector + + /** + * Sample size. + */ + def count: Long + + /** + * Number of nonzero elements (including explicitly presented zero values) in each column. + */ + def numNonzeros: Vector + + /** + * Maximum value of each column. + */ + def max: Vector + + /** + * Minimum value of each column. + */ + def min: Vector +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 19c8a7730cb0..c9f9acf4c133 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -137,9 +137,6 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { brzNorm(v, 1.0) < 1e-6 } - def equivVector(lhs: Vector, rhs: Vector): Boolean = - closeToZero(lhs.toBreeze.asInstanceOf[BDV[Double]] - rhs.toBreeze.asInstanceOf[BDV[Double]]) - def assertColumnEqualUpToSign(A: BDM[Double], B: BDM[Double], k: Int) { assert(A.rows === B.rows) for (j <- 0 until k) { @@ -174,45 +171,18 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { } } - test("dense statistical summary") { - val summary = denseMat.multiVariateSummaryStatistics() - - assert(equivVector(summary.mean, Vectors.dense(4.5, 3.0, 4.0)), - "Dense column mean do not match.") - - assert(equivVector(summary.variance, Vectors.dense(15.0, 10.0, 10.0)), - "Dense column variance do not match.") - - assert(summary.count === 4, "Dense column cnt do not match.") - - assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 4.0)), - "Dense column nnz do not match.") - - assert(equivVector(summary.max, Vectors.dense(9.0, 7.0, 8.0)), - "Dense column max do not match.") - - assert(equivVector(summary.min, Vectors.dense(0.0, 0.0, 1.0)), - "Dense column min do not match.") - } - - test("sparse statistical summary") { - val summary = sparseMat.multiVariateSummaryStatistics() - - assert(equivVector(summary.mean, Vectors.dense(4.5, 3.0, 4.0)), - "Sparse column mean do not match.") - - assert(equivVector(summary.variance, Vectors.dense(15.0, 10.0, 10.0)), - "Sparse column variance do not match.") - - assert(summary.count === 4, "Sparse column cnt do not match.") - - assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 4.0)), - "Sparse column nnz do not match.") - - assert(equivVector(summary.max, Vectors.dense(9.0, 7.0, 8.0)), - "Sparse column max do not match.") - - assert(equivVector(summary.min, Vectors.dense(0.0, 0.0, 1.0)), - "Sparse column min do not match.") + test("compute column summary statistics") { + for (mat <- Seq(denseMat, sparseMat)) { + val summary = mat.computeColumnSummaryStatistics() + // Run twice to make sure no internal states are changed. + for (k <- 0 to 1) { + assert(summary.mean === Vectors.dense(4.5, 3.0, 4.0), "mean mismatch") + assert(summary.variance === Vectors.dense(15.0, 10.0, 10.0), "variance mismatch") + assert(summary.count === m, "count mismatch.") + assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch") + assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch") + assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "column mismatch.") + } + } } }