Skip to content

Commit dd87181

Browse files
author
Ian Graves
committed
8214761: Bug in parallel Kahan summation implementation
Reviewed-by: darcy
1 parent 7fff22a commit dd87181

File tree

5 files changed

+235
-10
lines changed

5 files changed

+235
-10
lines changed

src/java.base/share/classes/java/util/DoubleSummaryStatistics.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ public void combine(DoubleSummaryStatistics other) {
156156
count += other.count;
157157
simpleSum += other.simpleSum;
158158
sumWithCompensation(other.sum);
159-
sumWithCompensation(other.sumCompensation);
159+
160+
// Subtract compensation bits
161+
sumWithCompensation(-other.sumCompensation);
160162
min = Math.min(min, other.min);
161163
max = Math.max(max, other.max);
162164
}
@@ -241,7 +243,7 @@ public final long getCount() {
241243
*/
242244
public final double getSum() {
243245
// Better error bounds to add both terms as the final sum
244-
double tmp = sum + sumCompensation;
246+
double tmp = sum - sumCompensation;
245247
if (Double.isNaN(tmp) && Double.isInfinite(simpleSum))
246248
// If the compensated sum is spuriously NaN from
247249
// accumulating one or more same-signed infinite values,

src/java.base/share/classes/java/util/stream/Collectors.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,8 @@ public static<T,A,R,RR> Collector<T,A,RR> collectingAndThen(Collector<T,A,R> dow
734734
a[2] += val;},
735735
(a, b) -> { sumWithCompensation(a, b[0]);
736736
a[2] += b[2];
737-
return sumWithCompensation(a, b[1]); },
737+
// Subtract compensation bits
738+
return sumWithCompensation(a, -b[1]); },
738739
a -> computeFinalSum(a),
739740
CH_NOID);
740741
}
@@ -765,8 +766,8 @@ static double[] sumWithCompensation(double[] intermediateSum, double value) {
765766
* correctly-signed infinity stored in the simple sum.
766767
*/
767768
static double computeFinalSum(double[] summands) {
768-
// Better error bounds to add both terms as the final sum
769-
double tmp = summands[0] + summands[1];
769+
// Final sum with better error bounds subtract second summand as it is negated
770+
double tmp = summands[0] - summands[1];
770771
double simpleSum = summands[summands.length - 1];
771772
if (Double.isNaN(tmp) && Double.isInfinite(simpleSum))
772773
return simpleSum;
@@ -840,13 +841,19 @@ static double computeFinalSum(double[] summands) {
840841
/*
841842
* In the arrays allocated for the collect operation, index 0
842843
* holds the high-order bits of the running sum, index 1 holds
843-
* the low-order bits of the sum computed via compensated
844+
* the negated low-order bits of the sum computed via compensated
844845
* summation, and index 2 holds the number of values seen.
845846
*/
846847
return new CollectorImpl<>(
847848
() -> new double[4],
848849
(a, t) -> { double val = mapper.applyAsDouble(t); sumWithCompensation(a, val); a[2]++; a[3]+= val;},
849-
(a, b) -> { sumWithCompensation(a, b[0]); sumWithCompensation(a, b[1]); a[2] += b[2]; a[3] += b[3]; return a; },
850+
(a, b) -> {
851+
sumWithCompensation(a, b[0]);
852+
// Subtract compensation bits
853+
sumWithCompensation(a, -b[1]);
854+
a[2] += b[2]; a[3] += b[3];
855+
return a;
856+
},
850857
a -> (a[2] == 0) ? 0.0d : (computeFinalSum(a) / a[2]),
851858
CH_NOID);
852859
}

src/java.base/share/classes/java/util/stream/DoublePipeline.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ public final double sum() {
442442
/*
443443
* In the arrays allocated for the collect operation, index 0
444444
* holds the high-order bits of the running sum, index 1 holds
445-
* the low-order bits of the sum computed via compensated
445+
* the negated low-order bits of the sum computed via compensated
446446
* summation, and index 2 holds the simple sum used to compute
447447
* the proper result if the stream contains infinite values of
448448
* the same sign.
@@ -454,7 +454,8 @@ public final double sum() {
454454
},
455455
(ll, rr) -> {
456456
Collectors.sumWithCompensation(ll, rr[0]);
457-
Collectors.sumWithCompensation(ll, rr[1]);
457+
// Subtract compensation bits
458+
Collectors.sumWithCompensation(ll, -rr[1]);
458459
ll[2] += rr[2];
459460
});
460461

@@ -497,7 +498,8 @@ public final OptionalDouble average() {
497498
},
498499
(ll, rr) -> {
499500
Collectors.sumWithCompensation(ll, rr[0]);
500-
Collectors.sumWithCompensation(ll, rr[1]);
501+
// Subtract compensation bits
502+
Collectors.sumWithCompensation(ll, -rr[1]);
501503
ll[2] += rr[2];
502504
ll[3] += rr[3];
503505
});
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* This code is free software; you can redistribute it and/or modify it
6+
* under the terms of the GNU General Public License version 2 only, as
7+
* published by the Free Software Foundation.
8+
*
9+
* This code is distributed in the hope that it will be useful, but WITHOUT
10+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12+
* version 2 for more details (a copy is included in the LICENSE file that
13+
* accompanied this code).
14+
*
15+
* You should have received a copy of the GNU General Public License version
16+
* 2 along with this work; if not, write to the Free Software Foundation,
17+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18+
*
19+
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20+
* or visit www.oracle.com if you need additional information or have any
21+
* questions.
22+
*/
23+
24+
/*
25+
* @test
26+
* @bug 8214761
27+
* @run testng CompensatedSums
28+
* @summary
29+
*/
30+
31+
import java.util.Random;
32+
import java.util.function.BiConsumer;
33+
import java.util.function.ObjDoubleConsumer;
34+
import java.util.function.Supplier;
35+
import java.util.stream.Collectors;
36+
import java.util.stream.DoubleStream;
37+
38+
import static org.testng.Assert.assertTrue;
39+
40+
import org.testng.Assert;
41+
import org.testng.annotations.Test;
42+
43+
public class CompensatedSums {
44+
45+
@Test
46+
public void testCompensatedSums() {
47+
double naive = 0;
48+
double jdkSequentialStreamError = 0;
49+
double goodSequentialStreamError = 0;
50+
double jdkParallelStreamError = 0;
51+
double goodParallelStreamError = 0;
52+
double badParallelStreamError = 0;
53+
54+
for (int loop = 0; loop < 100; loop++) {
55+
// sequence of random numbers of varying magnitudes, both positive and negative
56+
double[] rand = new Random().doubles(1_000_000)
57+
.map(Math::log)
58+
.map(x -> (Double.doubleToLongBits(x) % 2 == 0) ? x : -x)
59+
.toArray();
60+
61+
// base case: standard Kahan summation
62+
double[] sum = new double[2];
63+
for (int i=0; i < rand.length; i++) {
64+
sumWithCompensation(sum, rand[i]);
65+
}
66+
67+
// All error is the squared difference of the standard Kahan Sum vs JDK Stream sum implementation
68+
// Older less accurate implementations included here as the baseline.
69+
70+
// squared error of naive sum by reduction - should be large
71+
naive += Math.pow(DoubleStream.of(rand).reduce((x, y) -> x+y).getAsDouble() - sum[0], 2);
72+
73+
// squared error of sequential sum - should be 0
74+
jdkSequentialStreamError += Math.pow(DoubleStream.of(rand).sum() - sum[0], 2);
75+
76+
goodSequentialStreamError += Math.pow(computeFinalSum(DoubleStream.of(rand).collect(doubleSupplier,objDoubleConsumer,goodCollectorConsumer)) - sum[0], 2);
77+
78+
// squared error of parallel sum from the JDK
79+
jdkParallelStreamError += Math.pow(DoubleStream.of(rand).parallel().sum() - sum[0], 2);
80+
81+
// squared error of parallel sum
82+
goodParallelStreamError += Math.pow(computeFinalSum(DoubleStream.of(rand).parallel().collect(doubleSupplier,objDoubleConsumer,goodCollectorConsumer)) - sum[0], 2);
83+
84+
// the bad parallel stream
85+
badParallelStreamError += Math.pow(computeFinalSum(DoubleStream.of(rand).parallel().collect(doubleSupplier,objDoubleConsumer,badCollectorConsumer)) - sum[0], 2);
86+
87+
88+
}
89+
90+
Assert.assertEquals(goodSequentialStreamError, 0.0);
91+
Assert.assertEquals(goodSequentialStreamError, jdkSequentialStreamError);
92+
93+
Assert.assertTrue(jdkParallelStreamError <= goodParallelStreamError);
94+
Assert.assertTrue(badParallelStreamError > goodParallelStreamError);
95+
96+
Assert.assertTrue(naive > jdkSequentialStreamError);
97+
Assert.assertTrue(naive > jdkParallelStreamError);
98+
99+
}
100+
101+
// from OpenJDK8 Collectors, unmodified
102+
static double[] sumWithCompensation(double[] intermediateSum, double value) {
103+
double tmp = value - intermediateSum[1];
104+
double sum = intermediateSum[0];
105+
double velvel = sum + tmp; // Little wolf of rounding error
106+
intermediateSum[1] = (velvel - sum) - tmp;
107+
intermediateSum[0] = velvel;
108+
return intermediateSum;
109+
}
110+
111+
// from OpenJDK8 Collectors, unmodified
112+
static double computeFinalSum(double[] summands) {
113+
double tmp = summands[0] + summands[1];
114+
double simpleSum = summands[summands.length - 1];
115+
if (Double.isNaN(tmp) && Double.isInfinite(simpleSum))
116+
return simpleSum;
117+
else
118+
return tmp;
119+
}
120+
121+
//Suppliers and consumers for Double Stream summation collection.
122+
static Supplier<double[]> doubleSupplier = () -> new double[3];
123+
static ObjDoubleConsumer<double[]> objDoubleConsumer = (double[] ll, double d) -> {
124+
sumWithCompensation(ll, d);
125+
ll[2] += d;
126+
};
127+
static BiConsumer<double[], double[]> badCollectorConsumer =
128+
(ll, rr) -> {
129+
sumWithCompensation(ll, rr[0]);
130+
sumWithCompensation(ll, rr[1]);
131+
ll[2] += rr[2];
132+
};
133+
134+
static BiConsumer<double[], double[]> goodCollectorConsumer =
135+
(ll, rr) -> {
136+
sumWithCompensation(ll, rr[0]);
137+
sumWithCompensation(ll, -rr[1]);
138+
ll[2] += rr[2];
139+
};
140+
141+
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* This code is free software; you can redistribute it and/or modify it
6+
* under the terms of the GNU General Public License version 2 only, as
7+
* published by the Free Software Foundation.
8+
*
9+
* This code is distributed in the hope that it will be useful, but WITHOUT
10+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11+
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12+
* version 2 for more details (a copy is included in the LICENSE file that
13+
* accompanied this code).
14+
*
15+
* You should have received a copy of the GNU General Public License version
16+
* 2 along with this work; if not, write to the Free Software Foundation,
17+
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18+
*
19+
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20+
* or visit www.oracle.com if you need additional information or have any
21+
* questions.
22+
*/
23+
24+
/*
25+
* @test
26+
* @bug 8214761
27+
* @run testng NegativeCompensation
28+
* @summary When combining two DoubleSummaryStatistics, the compensation
29+
* has to be subtracted.
30+
*/
31+
32+
import java.util.DoubleSummaryStatistics;
33+
import org.testng.annotations.Test;
34+
import static org.testng.Assert.assertEquals;
35+
import static org.testng.Assert.assertTrue;
36+
37+
public class NegativeCompensation {
38+
static final double VAL = 1.000000001;
39+
static final int LOG_ITER = 21;
40+
41+
@Test
42+
public static void testErrorComparision() {
43+
DoubleSummaryStatistics stat0 = new DoubleSummaryStatistics();
44+
DoubleSummaryStatistics stat1 = new DoubleSummaryStatistics();
45+
DoubleSummaryStatistics stat2 = new DoubleSummaryStatistics();
46+
47+
stat1.accept(VAL);
48+
stat1.accept(VAL);
49+
stat2.accept(VAL);
50+
stat2.accept(VAL);
51+
stat2.accept(VAL);
52+
53+
for (int i = 0; i < LOG_ITER; ++i) {
54+
stat1.combine(stat2);
55+
stat2.combine(stat1);
56+
}
57+
58+
for (long i = 0, iend = stat2.getCount(); i < iend; ++i) {
59+
stat0.accept(VAL);
60+
}
61+
62+
double res = 0;
63+
for(long i = 0, iend = stat2.getCount(); i < iend; ++i) {
64+
res += VAL;
65+
}
66+
67+
double absErrN = Math.abs(res - stat2.getSum());
68+
double absErr = Math.abs(stat0.getSum() - stat2.getSum());
69+
assertTrue(absErrN >= absErr,
70+
"Naive sum error is not greater than or equal to Summary sum");
71+
assertEquals(absErr, 0.0, "Absolute error is not zero");
72+
}
73+
}

0 commit comments

Comments
 (0)