Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,11 @@ private[ml] object WeightedLeastSquares {
/**
* Weighted population standard deviation of labels.
*/
def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
def bStd: Double = {
// We prevent variance from negative value caused by numerical error.
val variance = math.max(bbSum / wSum - bBar * bBar, 0.0)
math.sqrt(variance)
}

/**
* Weighted mean of (label * features).
Expand Down Expand Up @@ -471,7 +475,8 @@ private[ml] object WeightedLeastSquares {
while (i < triK) {
val l = j - 2
val aw = aSum(l) / wSum
std(l) = math.sqrt(aaValues(i) / wSum - aw * aw)
// We prevent variance from negative value caused by numerical error.
std(l) = math.sqrt(math.max(aaValues(i) / wSum - aw * aw, 0.0))
i += j
j += 1
}
Expand All @@ -489,7 +494,8 @@ private[ml] object WeightedLeastSquares {
while (i < triK) {
val l = j - 2
val aw = aSum(l) / wSum
variance(l) = aaValues(i) / wSum - aw * aw
// We prevent variance from negative value caused by numerical error.
variance(l) = math.max(aaValues(i) / wSum - aw * aw, 0.0)
i += j
j += 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
var i = 0
val len = currM2n.length
while (i < len) {
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
// We prevent variance from negative value caused by numerical error.
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
var i = 0
val len = currM2n.length
while (i < len) {
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
// We prevent variance from negative value caused by numerical error.
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,24 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(summarizer.count === 6)
}

test("summarizer buffer zero variance test (SPARK-21818)") {
val summarizer1 = new SummarizerBuffer()
.add(Vectors.dense(3.0), 0.7)
val summarizer2 = new SummarizerBuffer()
.add(Vectors.dense(3.0), 0.4)
val summarizer3 = new SummarizerBuffer()
.add(Vectors.dense(3.0), 0.5)
val summarizer4 = new SummarizerBuffer()
.add(Vectors.dense(3.0), 0.4)

val summarizer = summarizer1
.merge(summarizer2)
.merge(summarizer3)
.merge(summarizer4)

assert(summarizer.variance(0) >= 0.0)
}

test("summarizer buffer merging summarizer with empty summarizer") {
// If one of two is non-empty, this should return the non-empty summarizer.
// If both of them are empty, then just return the empty summarizer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
}

test ("test zero variance (SPARK-21818)") {
val summarizer1 = (new MultivariateOnlineSummarizer)
.add(Vectors.dense(3.0), 0.7)
val summarizer2 = (new MultivariateOnlineSummarizer)
.add(Vectors.dense(3.0), 0.4)
val summarizer3 = (new MultivariateOnlineSummarizer)
.add(Vectors.dense(3.0), 0.5)
val summarizer4 = (new MultivariateOnlineSummarizer)
.add(Vectors.dense(3.0), 0.4)

val summarizer = summarizer1
.merge(summarizer2)
.merge(summarizer3)
.merge(summarizer4)

assert(summarizer.variance(0) >= 0.0)
}
}