From 9c92730bc3588596b348932ea285b12c5a4a77ce Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 23 Aug 2017 18:52:56 +0800 Subject: [PATCH 1/6] init pr --- .../org/apache/spark/ml/stat/Summarizer.scala | 4 ++++ .../stat/MultivariateOnlineSummarizer.scala | 4 ++++ .../apache/spark/ml/stat/SummarizerSuite.scala | 18 ++++++++++++++++++ .../MultivariateOnlineSummarizerSuite.scala | 18 ++++++++++++++++++ 4 files changed, 44 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 7e408b9dbd13..2bce89be3ff3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -438,6 +438,10 @@ private[ml] object SummaryBuilderImpl extends Logging { while (i < len) { realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator + // Because of numerical error, it is possible to get negative real variance + if (realVariance(i) < 0.0) { + realVariance(i) = 0.0 + } i += 1 } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 7dc0c459ec03..857023731c59 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -215,6 +215,10 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S while (i < len) { realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator + // Because of numerical error, it is possible to get negative real variance + if (realVariance(i) < 0.0) { + realVariance(i) = 0.0 + } i += 1 } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala index dfb733ff6e76..1ea851ef2d67 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala @@ -402,6 +402,24 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(summarizer.count === 6) } + test("summarizer buffer zero variance test (SPARK-21818)") { + val summarizer1 = new SummarizerBuffer() + .add(Vectors.dense(3.0), 0.7) + val summarizer2 = new SummarizerBuffer() + .add(Vectors.dense(3.0), 0.4) + val summarizer3 = new SummarizerBuffer() + .add(Vectors.dense(3.0), 0.5) + val summarizer4 = new SummarizerBuffer() + .add(Vectors.dense(3.0), 0.4) + + val summarizer = summarizer1 + .merge(summarizer2) + .merge(summarizer3) + .merge(summarizer4) + + assert(summarizer.variance(0) >= 0.0) + } + test("summarizer buffer merging summarizer with empty summarizer") { // If one of two is non-empty, this should return the non-empty summarizer. // If both of them are empty, then just return the empty summarizer. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 797e84fcc737..c6466bc918dd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) } + + test ("test zero variance (SPARK-21818)") { + val summarizer1 = (new MultivariateOnlineSummarizer) + .add(Vectors.dense(3.0), 0.7) + val summarizer2 = (new MultivariateOnlineSummarizer) + .add(Vectors.dense(3.0), 0.4) + val summarizer3 = (new MultivariateOnlineSummarizer) + .add(Vectors.dense(3.0), 0.5) + val summarizer4 = (new MultivariateOnlineSummarizer) + .add(Vectors.dense(3.0), 0.4) + + val summarizer = summarizer1 + .merge(summarizer2) + .merge(summarizer3) + .merge(summarizer4) + + assert(summarizer.variance(0) >= 0.0) + } } From c24292ccad700d39892a576390cff2559c4f3b9a Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 25 Aug 2017 13:44:35 +0800 Subject: [PATCH 2/6] update --- .../org/apache/spark/ml/optim/WeightedLeastSquares.scala | 2 +- .../main/scala/org/apache/spark/ml/stat/Summarizer.scala | 8 ++------ .../spark/mllib/stat/MultivariateOnlineSummarizer.scala | 8 ++------ 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 32b0af72ba9b..46da6881dd9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -440,7 +440,7 @@ private[ml] object WeightedLeastSquares { /** * Weighted population standard deviation of labels. */ - def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar) + def bStd: Double = math.sqrt(math.max(bbSum / wSum - bBar * bBar, 0.0)) /** * Weighted mean of (label * features). diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 2bce89be3ff3..49ae80bba96c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -436,12 +436,8 @@ private[ml] object SummaryBuilderImpl extends Logging { var i = 0 val len = currM2n.length while (i < len) { - realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * - (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator - // Because of numerical error, it is possible to get negative real variance - if (realVariance(i) < 0.0) { - realVariance(i) = 0.0 - } + realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * + (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0) i += 1 } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 857023731c59..c5bdfd9a6485 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -213,12 +213,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 val len = currM2n.length while (i < len) { - realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * - (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator - // Because of numerical error, it is possible to get negative real variance - if (realVariance(i) < 0.0) { - realVariance(i) = 0.0 - } + realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * + (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0) i += 1 } } From 9a47579194f885815b9d298435b7b56a9649da2c Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 25 Aug 2017 17:20:55 +0800 Subject: [PATCH 3/6] add comment --- .../scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala | 1 + mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala | 1 + .../apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala | 1 + 3 files changed, 3 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 46da6881dd9d..dbf7e7bc07a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -439,6 +439,7 @@ private[ml] object WeightedLeastSquares { /** * Weighted population standard deviation of labels. + * We prevent `variance` from negative value caused by numerical error. */ def bStd: Double = math.sqrt(math.max(bbSum / wSum - bBar * bBar, 0.0)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 49ae80bba96c..d597ecd3eab9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -436,6 +436,7 @@ private[ml] object SummaryBuilderImpl extends Logging { var i = 0 val len = currM2n.length while (i < len) { + // We prevent `variance` from negative value caused by numerical error. realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0) i += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index c5bdfd9a6485..febe7068eb67 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -213,6 +213,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 val len = currM2n.length while (i < len) { + // We prevent `variance` from negative value caused by numerical error. realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0) i += 1 From 56c0d41f1517a49a935464933a8021008d8a32f7 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 25 Aug 2017 17:35:28 +0800 Subject: [PATCH 4/6] update comment --- .../scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala | 2 +- .../apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index dbf7e7bc07a0..4ae85c1f73b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -439,7 +439,7 @@ private[ml] object WeightedLeastSquares { /** * Weighted population standard deviation of labels. - * We prevent `variance` from negative value caused by numerical error. + * We prevent variance from negative value caused by numerical error. */ def bStd: Double = math.sqrt(math.max(bbSum / wSum - bBar * bBar, 0.0)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index d597ecd3eab9..cae41edb7aca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -436,7 +436,7 @@ private[ml] object SummaryBuilderImpl extends Logging { var i = 0 val len = currM2n.length while (i < len) { - // We prevent `variance` from negative value caused by numerical error. + // We prevent variance from negative value caused by numerical error. realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0) i += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index febe7068eb67..8121880cfb23 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -213,7 +213,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 val len = currM2n.length while (i < len) { - // We prevent `variance` from negative value caused by numerical error. + // We prevent variance from negative value caused by numerical error. realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0) i += 1 From 21e7ff7ea65da1c03b32445405d2bd55346db096 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Sun, 27 Aug 2017 18:24:30 +0800 Subject: [PATCH 5/6] update --- .../org/apache/spark/ml/optim/WeightedLeastSquares.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 4ae85c1f73b1..b706491e19d3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -472,7 +472,8 @@ private[ml] object WeightedLeastSquares { while (i < triK) { val l = j - 2 val aw = aSum(l) / wSum - std(l) = math.sqrt(aaValues(i) / wSum - aw * aw) + // We prevent variance from negative value caused by numerical error. + std(l) = math.sqrt(math.max(aaValues(i) / wSum - aw * aw, 0.0)) i += j j += 1 } @@ -490,7 +491,8 @@ private[ml] object WeightedLeastSquares { while (i < triK) { val l = j - 2 val aw = aSum(l) / wSum - variance(l) = aaValues(i) / wSum - aw * aw + // We prevent variance from negative value caused by numerical error. + variance(l) = math.max(aaValues(i) / wSum - aw * aw, 0.0) i += j j += 1 } From c40eba38d82893d5604aa66ec9037df706da712d Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 28 Aug 2017 09:30:55 +0800 Subject: [PATCH 6/6] update --- .../org/apache/spark/ml/optim/WeightedLeastSquares.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index b706491e19d3..1ed218aa58bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -439,9 +439,12 @@ private[ml] object WeightedLeastSquares { /** * Weighted population standard deviation of labels. - * We prevent variance from negative value caused by numerical error. */ - def bStd: Double = math.sqrt(math.max(bbSum / wSum - bBar * bBar, 0.0)) + def bStd: Double = { + // We prevent variance from negative value caused by numerical error. + val variance = math.max(bbSum / wSum - bBar * bBar, 0.0) + math.sqrt(variance) + } /** * Weighted mean of (label * features).