@@ -26,75 +26,39 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary

/**
* Trait of the summary statistics, including mean, variance, count, max, min, and non-zero elements
* count.
* Column statistics aggregator implementing
* [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
* together with add() and merge() functions.
* A numerically stable algorithm is implemented to compute sample mean and variance:
* [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
* Zero elements (including explicit zero values) are skipped when calling add() and merge(),
* to have time complexity O(nnz) instead of O(n) for each column.
*/
trait VectorRDDStatisticalSummary {
private class ColumnStatisticsAggregator(private val n: Int)
extends MultivariateStatisticalSummary with Serializable {

/**
* Computes the mean of columns in RDD[Vector].
*/
def mean: Vector

/**
* Computes the sample variance of columns in RDD[Vector].
*/
def variance: Vector

/**
* Computes number of vectors in RDD[Vector].
*/
def count: Long

/**
* Computes the number of non-zero elements in each column of RDD[Vector].
*/
def numNonZeros: Vector

/**
* Computes the maximum of each column in RDD[Vector].
*/
def max: Vector

/**
* Computes the minimum of each column in RDD[Vector].
*/
def min: Vector
}


/**
* Aggregates [[org.apache.spark.mllib.linalg.distributed.VectorRDDStatisticalSummary
* VectorRDDStatisticalSummary]] together with add() and merge() function. Online variance solution
* used in add() function, while parallel variance solution used in merge() function. Reference here
* : [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]. Solution
* here ignoring the zero elements when calling add() and merge(), for decreasing the O(n) algorithm
* to O(nnz). Real variance is computed here after we get other statistics, simply by another
* parallel combination process.
*/
private class VectorRDDStatisticsAggregator(
val currMean: BDV[Double],
val currM2n: BDV[Double],
var totalCnt: Double,
val nnz: BDV[Double],
val currMax: BDV[Double],
val currMin: BDV[Double])
extends VectorRDDStatisticalSummary with Serializable {
private val currMean: BDV[Double] = BDV.zeros[Double](n)
private val currM2n: BDV[Double] = BDV.zeros[Double](n)
private var totalCnt: Double = 0.0
private val nnz: BDV[Double] = BDV.zeros[Double](n)
private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)

override def mean = {
val realMean = BDV.zeros[Double](currMean.length)
val realMean = BDV.zeros[Double](n)
var i = 0
while (i < currMean.length) {
while (i < n) {
realMean(i) = currMean(i) * nnz(i) / totalCnt
i += 1
}
Vectors.fromBreeze(realMean)
}

override def variance = {
val realVariance = BDV.zeros[Double](currM2n.length)
val realVariance = BDV.zeros[Double](n)

val denominator = totalCnt - 1.0

@@ -115,59 +79,60 @@ private class VectorRDDStatisticsAggregator(

override def count: Long = totalCnt.toLong

override def numNonZeros: Vector = Vectors.fromBreeze(nnz)
override def numNonzeros: Vector = Vectors.fromBreeze(nnz)

override def max: Vector = {
var i = 0
while (i < nnz.length) {
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMax)
}

override def min: Vector = {
var i = 0
while (i < nnz.length) {
while (i < n) {
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMin)
}

/**
* Aggregate function used for aggregating elements in a worker together.
* Aggregates a row.
*/
def add(currData: BV[Double]): this.type = {
currData.activeIterator.foreach {
// this case is used for filtering the zero elements if the vector.
case (id, 0.0) =>
case (id, value) =>
if (currMax(id) < value) currMax(id) = value
if (currMin(id) > value) currMin(id) = value
case (_, 0.0) => // Skip explicit zero elements.
case (i, value) =>
if (currMax(i) < value) currMax(i) = value
if (currMin(i) > value) currMin(i) = value

val tmpPrevMean = currMean(id)
currMean(id) = (currMean(id) * nnz(id) + value) / (nnz(id) + 1.0)
currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean)
val tmpPrevMean = currMean(i)
currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)

nnz(id) += 1.0
nnz(i) += 1.0
}

totalCnt += 1.0
this
}

/**
* Combine function used for combining intermediate results together from every worker.
* Merges another aggregator.
*/
def merge(other: VectorRDDStatisticsAggregator): this.type = {
def merge(other: ColumnStatisticsAggregator): this.type = {

require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")

totalCnt += other.totalCnt

val deltaMean = currMean - other.currMean

var i = 0
while (i < other.currMean.length) {
while (i < n) {
// merge mean together
if (other.currMean(i) != 0.0) {
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
Expand All @@ -188,6 +153,7 @@ private class VectorRDDStatisticsAggregator(
}

nnz += other.nnz

this
}
}
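
// Illustrative sketch (not part of this patch): the aggregator above combines an
// online (Welford-style) update in add() with the pairwise merge from the
// "Algorithms for calculating variance" article referenced in its scaladoc. The
// minimal, single-column sketch below shows just those two rules on plain doubles;
// the per-column bookkeeping, zero skipping (nnz), and max/min tracking of
// ColumnStatisticsAggregator are omitted, and all names here are illustrative only.
object OnlineVarianceSketch {
  final case class Partial(count: Double, mean: Double, m2n: Double)

  /** Online update: folds one value into a partial (count, mean, M2n) summary. */
  def add(p: Partial, value: Double): Partial = {
    val count = p.count + 1.0
    val delta = value - p.mean
    val mean = p.mean + delta / count
    Partial(count, mean, p.m2n + delta * (value - mean))
  }

  /** Parallel merge: combines two partial summaries computed on disjoint data. */
  def merge(a: Partial, b: Partial): Partial = {
    if (a.count == 0.0) b
    else if (b.count == 0.0) a
    else {
      val count = a.count + b.count
      val delta = b.mean - a.mean
      val mean = a.mean + delta * b.count / count
      val m2n = a.m2n + b.m2n + delta * delta * a.count * b.count / count
      Partial(count, mean, m2n)
    }
  }

  def main(args: Array[String]): Unit = {
    val (left, right) = Seq(1.0, 2.0, 4.0, 8.0).splitAt(2)
    val zero = Partial(0.0, 0.0, 0.0)
    val merged = merge(left.foldLeft(zero)(add), right.foldLeft(zero)(add))
    // Sample variance = M2n / (count - 1); for these four values it is 9.5833...
    println(merged.m2n / (merged.count - 1.0))
  }
}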
@@ -344,13 +310,7 @@ class RowMatrix(
combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
)

// Update _m if it is not set, or verify its value.
if (nRows <= 0L) {
nRows = m
} else {
require(nRows == m,
s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
}
updateNumRows(m)

mean :/= m.toDouble

@@ -403,21 +363,16 @@ class RowMatrix(
}

/**
* Compute full column-wise statistics for the RDD with the size of Vector as input parameter.
* Computes column-wise summary statistics.
*/
def multiVariateSummaryStatistics(): VectorRDDStatisticalSummary = {
val zeroValue = new VectorRDDStatisticsAggregator(
BDV.zeros[Double](nCols),
BDV.zeros[Double](nCols),
0.0,
BDV.zeros[Double](nCols),
BDV.fill(nCols)(Double.MinValue),
BDV.fill(nCols)(Double.MaxValue))

rows.map(_.toBreeze).aggregate[VectorRDDStatisticsAggregator](zeroValue)(
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
)
updateNumRows(summary.count)
summary
}

/**
@@ -456,6 +411,27 @@ class RowMatrix(
}
mat
}

/** Updates or verifies the number of columns. */
private def updateNumCols(n: Int) {
if (nCols <= 0) {
nCols = n
} else {
require(nCols == n,
s"The number of columns $n is different from " +
s"what specified or previously computed: ${nCols}.")
}
}

/** Updates or verifies the number of rows. */
private def updateNumRows(m: Long) {
if (nRows <= 0) {
nRows = m
} else {
require(nRows == m,
s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
}
}
}

object RowMatrix {
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.stat

import org.apache.spark.mllib.linalg.Vector

/**
* Trait for multivariate statistical summary of a data matrix.
*/
trait MultivariateStatisticalSummary {

/**
* Sample mean vector.
*/
def mean: Vector

/**
* Sample variance vector. Should return a zero vector if the sample size is 1.
*/
def variance: Vector

/**
* Sample size.
*/
def count: Long

/**
* Number of nonzero elements in each column (explicitly presented zero values are not counted).
*/
def numNonzeros: Vector

/**
* Maximum value of each column.
*/
def max: Vector

/**
* Minimum value of each column.
*/
def min: Vector
}
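
// Illustrative usage sketch (not part of this patch): how the new column summary
// API might be exercised from a driver or the spark-shell, assuming an existing
// SparkContext named `sc`. The per-column values in the comments are worked out
// from the two example rows below.
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 0.0, 3.0),
  Vectors.sparse(3, Seq((0, 2.0), (2, 4.0)))))
val mat = new RowMatrix(rows)
val summary = mat.computeColumnSummaryStatistics()

println(summary.mean)        // columns: 1.5, 0.0, 3.5
println(summary.numNonzeros) // columns: 2.0, 0.0, 2.0 -- the explicit 0.0 is not counted
println(summary.max)         // columns: 2.0, 0.0, 4.0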
@@ -137,9 +134,6 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
brzNorm(v, 1.0) < 1e-6
}

def equivVector(lhs: Vector, rhs: Vector): Boolean =
closeToZero(lhs.toBreeze.asInstanceOf[BDV[Double]] - rhs.toBreeze.asInstanceOf[BDV[Double]])

def assertColumnEqualUpToSign(A: BDM[Double], B: BDM[Double], k: Int) {
assert(A.rows === B.rows)
for (j <- 0 until k) {
@@ -174,45 +171,18 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
}
}

test("dense statistical summary") {
val summary = denseMat.multiVariateSummaryStatistics()

assert(equivVector(summary.mean, Vectors.dense(4.5, 3.0, 4.0)),
"Dense column mean do not match.")

assert(equivVector(summary.variance, Vectors.dense(15.0, 10.0, 10.0)),
"Dense column variance do not match.")

assert(summary.count === 4, "Dense column cnt do not match.")

assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 4.0)),
"Dense column nnz do not match.")

assert(equivVector(summary.max, Vectors.dense(9.0, 7.0, 8.0)),
"Dense column max do not match.")

assert(equivVector(summary.min, Vectors.dense(0.0, 0.0, 1.0)),
"Dense column min do not match.")
}

test("sparse statistical summary") {
val summary = sparseMat.multiVariateSummaryStatistics()

assert(equivVector(summary.mean, Vectors.dense(4.5, 3.0, 4.0)),
"Sparse column mean do not match.")

assert(equivVector(summary.variance, Vectors.dense(15.0, 10.0, 10.0)),
"Sparse column variance do not match.")

assert(summary.count === 4, "Sparse column cnt do not match.")

assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 4.0)),
"Sparse column nnz do not match.")

assert(equivVector(summary.max, Vectors.dense(9.0, 7.0, 8.0)),
"Sparse column max do not match.")

assert(equivVector(summary.min, Vectors.dense(0.0, 0.0, 1.0)),
"Sparse column min do not match.")
test("compute column summary statistics") {
for (mat <- Seq(denseMat, sparseMat)) {
val summary = mat.computeColumnSummaryStatistics()
// Run twice to make sure no internal states are changed.
for (k <- 0 to 1) {
assert(summary.mean === Vectors.dense(4.5, 3.0, 4.0), "mean mismatch")
assert(summary.variance === Vectors.dense(15.0, 10.0, 10.0), "variance mismatch")
assert(summary.count === m, "count mismatch.")
assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch")
assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch")
assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "min mismatch")
}
}
}
}