Skip to content

Commit 9c92730

Browse files
committed
init pr
1 parent 77d046e commit 9c92730

File tree

4 files changed

+44
-0
lines changed

4 files changed

+44
-0
lines changed

mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,10 @@ private[ml] object SummaryBuilderImpl extends Logging {
438438
while (i < len) {
439439
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
440440
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
441+
// Because of numerical error, it is possible to get negative real variance
442+
if (realVariance(i) < 0.0) {
443+
realVariance(i) = 0.0
444+
}
441445
i += 1
442446
}
443447
}

mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,10 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
215215
while (i < len) {
216216
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
217217
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
218+
// Because of numerical error, it is possible to get negative real variance
219+
if (realVariance(i) < 0.0) {
220+
realVariance(i) = 0.0
221+
}
218222
i += 1
219223
}
220224
}

mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,24 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
402402
assert(summarizer.count === 6)
403403
}
404404

405+
test("summarizer buffer zero variance test (SPARK-21818)") {
406+
val summarizer1 = new SummarizerBuffer()
407+
.add(Vectors.dense(3.0), 0.7)
408+
val summarizer2 = new SummarizerBuffer()
409+
.add(Vectors.dense(3.0), 0.4)
410+
val summarizer3 = new SummarizerBuffer()
411+
.add(Vectors.dense(3.0), 0.5)
412+
val summarizer4 = new SummarizerBuffer()
413+
.add(Vectors.dense(3.0), 0.4)
414+
415+
val summarizer = summarizer1
416+
.merge(summarizer2)
417+
.merge(summarizer3)
418+
.merge(summarizer4)
419+
420+
assert(summarizer.variance(0) >= 0.0)
421+
}
422+
405423
test("summarizer buffer merging summarizer with empty summarizer") {
406424
// If one of two is non-empty, this should return the non-empty summarizer.
407425
// If both of them are empty, then just return the empty summarizer.

mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
270270
assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
271271
assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
272272
}
273+
274+
test ("test zero variance (SPARK-21818)") {
275+
val summarizer1 = (new MultivariateOnlineSummarizer)
276+
.add(Vectors.dense(3.0), 0.7)
277+
val summarizer2 = (new MultivariateOnlineSummarizer)
278+
.add(Vectors.dense(3.0), 0.4)
279+
val summarizer3 = (new MultivariateOnlineSummarizer)
280+
.add(Vectors.dense(3.0), 0.5)
281+
val summarizer4 = (new MultivariateOnlineSummarizer)
282+
.add(Vectors.dense(3.0), 0.4)
283+
284+
val summarizer = summarizer1
285+
.merge(summarizer2)
286+
.merge(summarizer3)
287+
.merge(summarizer4)
288+
289+
assert(summarizer.variance(0) >= 0.0)
290+
}
273291
}

0 commit comments

Comments
 (0)