Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,29 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var currMax: BDV[Double] = _
private var currMin: BDV[Double] = _

/**
* Adds input value to position i.
*/
private[this] def add(i: Int, value: Double) = {
if (value != 0.0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry if this a dumb question -- and this isn't a change in this PR -- but why can't a sample of value 0 be added to the summarizer?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add it, and get the same result. However, it's computationally cheap if we don't add zero into the summarizer.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0 affects the mean, and could affect min/max, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. However, we know the total # of samples, and # of nonzero in each column, so if # of samples and # of nonzero are different, and we find the min is some positive number, then the actually min will be zero since we have zero somewhere which we don't add into summarizer.

For max, the same logic will be applied.

For mean, we can fix this effect by realMean(i) = currMean(i) * (nnz(i) / totalCnt)

As a result, for sparse dataset, we only need to add the nonzero into the summarizer, and it will be O(\bar{n}) instead of O(n) where \bar{n} is the average nonzero elements in one sample.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah right, I get it now.

if (currMax(i) < value) {
currMax(i) = value
}
if (currMin(i) > value) {
currMin(i) = value
}

val prevMean = currMean(i)
val diff = value - prevMean
currMean(i) = prevMean + diff / (nnz(i) + 1.0)
currM2n(i) += (value - currMean(i)) * diff
currM2(i) += value * value
currL1(i) += math.abs(value)

nnz(i) += 1.0
}
}

/**
* Add a new sample to this summarizer, and update the statistical summary.
*
Expand All @@ -72,37 +95,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")

@inline def update(i: Int, value: Double) = {
if (value != 0.0) {
if (currMax(i) < value) {
currMax(i) = value
}
if (currMin(i) > value) {
currMin(i) = value
}

val tmpPrevMean = currMean(i)
currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
currM2(i) += value * value
currL1(i) += math.abs(value)

nnz(i) += 1.0
}
}

sample match {
case dv: DenseVector => {
var j = 0
while (j < dv.size) {
update(j, dv.values(j))
add(j, dv.values(j))
j += 1
}
}
case sv: SparseVector =>
var j = 0
while (j < sv.indices.size) {
update(sv.indices(j), sv.values(j))
add(sv.indices(j), sv.values(j))
j += 1
}
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
Expand All @@ -124,37 +128,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
s"Expecting $n but got ${other.n}.")
totalCnt += other.totalCnt
val deltaMean: BDV[Double] = currMean - other.currMean
var i = 0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is wrong because we still need to consider the weight.

while (i < n) {
// merge mean together
if (other.currMean(i) != 0.0) {
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
(nnz(i) + other.nnz(i))
}
// merge m2n together
if (nnz(i) + other.nnz(i) != 0.0) {
currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
(nnz(i) + other.nnz(i))
}
// merge m2 together
if (nnz(i) + other.nnz(i) != 0.0) {
val thisNnz = nnz(i)
val otherNnz = other.nnz(i)
val totalNnz = thisNnz + otherNnz
if (totalNnz != 0.0) {
val deltaMean = other.currMean(i) - currMean(i)
// merge mean together
currMean(i) += deltaMean * otherNnz / totalNnz
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good. More consistent with the previous notation when we add single sample.

// merge m2n together
currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
// merge m2 together
currM2(i) += other.currM2(i)
}
// merge l1 together
if (nnz(i) + other.nnz(i) != 0.0) {
// merge l1 together
currL1(i) += other.currL1(i)
// merge max and min
currMax(i) = math.max(currMax(i), other.currMax(i))
currMin(i) = math.min(currMin(i), other.currMin(i))
}

if (currMax(i) < other.currMax(i)) {
currMax(i) = other.currMax(i)
}
if (currMin(i) > other.currMin(i)) {
currMin(i) = other.currMin(i)
}
nnz(i) = totalNnz
i += 1
}
nnz += other.nnz
} else if (totalCnt == 0 && other.totalCnt != 0) {
this.n = other.n
this.currMean = other.currMean.copy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {

assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
}

test("merging summarizer when one side has zero mean (SPARK-4355)") {
val s0 = new MultivariateOnlineSummarizer()
.add(Vectors.dense(2.0))
.add(Vectors.dense(2.0))
val s1 = new MultivariateOnlineSummarizer()
.add(Vectors.dense(1.0))
.add(Vectors.dense(-1.0))
s0.merge(s1)
assert(s0.mean(0) ~== 1.0 absTol 1e-14)
}
}