Skip to content

Commit c24292c

Browse files
committed
update
1 parent 9c92730 commit c24292c

File tree

3 files changed

+5
-13
lines changed

3 files changed

+5
-13
lines changed

mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ private[ml] object WeightedLeastSquares {
440440
/**
441441
* Weighted population standard deviation of labels.
442442
*/
443-
def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
443+
def bStd: Double = math.sqrt(math.max(bbSum / wSum - bBar * bBar, 0.0))
444444

445445
/**
446446
* Weighted mean of (label * features).

mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -436,12 +436,8 @@ private[ml] object SummaryBuilderImpl extends Logging {
436436
var i = 0
437437
val len = currM2n.length
438438
while (i < len) {
439-
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
440-
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
441-
// Because of numerical error, it is possible to get negative real variance
442-
if (realVariance(i) < 0.0) {
443-
realVariance(i) = 0.0
444-
}
439+
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
440+
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
445441
i += 1
446442
}
447443
}

mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
213213
var i = 0
214214
val len = currM2n.length
215215
while (i < len) {
216-
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
217-
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
218-
// Because of numerical error, it is possible to get negative real variance
219-
if (realVariance(i) < 0.0) {
220-
realVariance(i) = 0.0
221-
}
216+
realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
217+
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
222218
i += 1
223219
}
224220
}

0 commit comments

Comments
 (0)