Skip to content

Commit 208d0f2

Browse files
authored
Fix BigFloat for Julia v1.12 (#307)
1 parent 381a59d commit 208d0f2

File tree

2 files changed

+86
-143
lines changed

2 files changed

+86
-143
lines changed

src/implementations/BigFloat.jl

Lines changed: 56 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,20 @@
99

1010
mutability(::Type{BigFloat}) = IsMutable()
1111

12-
# Copied from `deepcopy_internal` implementation in Julia:
13-
# https://github.com/JuliaLang/julia/blob/7d41d1eb610cad490cbaece8887f9bbd2a775021/base/mpfr.jl#L1041-L1050
14-
function mutable_copy(x::BigFloat)
15-
d = x._d
16-
d′ = GC.@preserve d unsafe_string(pointer(d), sizeof(d)) # creates a definitely-new String
17-
return Base.MPFR._BigFloat(x.prec, x.sign, x.exp, d′)
12+
# These methods are copied from `deepcopy_internal` in `base/mpfr.jl`. We don't
13+
# use `mutable_copy(x) = deepcopy(x)` because this creates an empty `IdDict()`
14+
# which costs some extra allocations. We don't need the IdDict case because we
15+
# never call `mutable_copy` recursively.
16+
@static if VERSION >= v"1.12.0-DEV.1343"
17+
mutable_copy(x::BigFloat) = Base.MPFR._BigFloat(copy(getfield(x, :d)))
18+
else
19+
function mutable_copy(x::BigFloat)
20+
d = x._d
21+
GC.@preserve d begin
22+
d′ = unsafe_string(pointer(d), sizeof(d))
23+
return Base.MPFR._BigFloat(x.prec, x.sign, x.exp, d′)
24+
end
25+
end
1826
end
1927

2028
const _MPFRRoundingMode = Base.MPFR.MPFRRoundingMode
@@ -297,12 +305,12 @@ function operate_to!(
297305
end
298306

299307
struct DotBuffer{F<:Real}
300-
compensation::F
301-
summation_temp::F
302-
multiplication_temp::F
303-
inner_temp::F
308+
c::F
309+
t::F
310+
input::F
311+
tmp::F
304312

305-
DotBuffer{F}() where {F<:Real} = new{F}(ntuple(i -> F(), Val{4}())...)
313+
DotBuffer{F}() where {F<:Real} = new{F}(zero(F), zero(F), zero(F), zero(F))
306314
end
307315

308316
function buffer_for(
@@ -366,87 +374,61 @@ end
366374
#
367375
# function KahanBabushkaNeumaierSum(input)
368376
# sum = 0.0
369-
#
370377
# # A running compensation for lost low-order bits.
371378
# c = 0.0
372-
#
373379
# for i ∈ eachindex(input)
374380
# t = sum + input[i]
375-
#
376381
# if abs(input[i]) ≤ abs(sum)
377-
# c += (sum - t) + input[i]
382+
# tmp = (sum - t) + input[i]
378383
# else
379-
# c += (input[i] - t) + sum
384+
# tmp = (input[i] - t) + sum
380385
# end
381-
#
386+
# c += tmp
382387
# sum = t
383388
# end
384-
#
385-
# # The result, with the correction only applied once in the very
386-
# # end.
389+
# # The result, with the correction only applied once in the very end.
387390
# sum + c
388391
# end
389-
function buffered_operate_to!(
390-
buf::DotBuffer{F},
391-
sum::F,
392-
::typeof(LinearAlgebra.dot),
393-
x::AbstractVector{F},
394-
y::AbstractVector{F},
395-
) where {F<:BigFloat}
396-
set! = (o, i) -> operate_to!(o, copy, i)
397392

398-
local swap! = function (x::BigFloat, y::BigFloat)
399-
ccall((:mpfr_swap, :libmpfr), Cvoid, (Ref{BigFloat}, Ref{BigFloat}), x, y)
400-
return nothing
393+
# Returns abs(x) <= abs(y) without allocating.
394+
function _abs_lte_abs(x::BigFloat, y::BigFloat)
395+
x_is_neg, y_is_neg = signbit(x), signbit(y)
396+
if x_is_neg != y_is_neg
397+
operate!(-, x)
401398
end
402-
403-
# Returns abs(x) <= abs(y) without allocating.
404-
local abs_lte_abs = function (x::F, y::F)
405-
local x_is_neg = signbit(x)
406-
local y_is_neg = signbit(y)
407-
408-
local x_neg = x_is_neg != y_is_neg
409-
410-
x_neg && operate!(-, x)
411-
412-
local ret = if y_is_neg
413-
y <= x
414-
else
415-
x <= y
416-
end
417-
418-
x_neg && operate!(-, x)
419-
420-
return ret
399+
ret = y_is_neg ? y <= x : x <= y
400+
if x_is_neg != y_is_neg
401+
operate!(-, x)
421402
end
403+
return ret
404+
end
422405

423-
operate!(zero, sum)
424-
operate!(zero, buf.compensation)
425-
426-
for i in 0:(length(x)-1)
427-
set!(buf.multiplication_temp, x[begin+i])
428-
operate!(*, buf.multiplication_temp, y[begin+i])
429-
430-
operate!(zero, buf.summation_temp)
431-
operate_to!(buf.summation_temp, +, buf.multiplication_temp, sum)
432-
433-
if abs_lte_abs(buf.multiplication_temp, sum)
434-
set!(buf.inner_temp, sum)
435-
operate!(-, buf.inner_temp, buf.summation_temp)
436-
operate!(+, buf.inner_temp, buf.multiplication_temp)
406+
function buffered_operate_to!(
407+
buf::DotBuffer{BigFloat},
408+
sum::BigFloat,
409+
::typeof(LinearAlgebra.dot),
410+
x::AbstractVector{BigFloat},
411+
y::AbstractVector{BigFloat},
412+
) # See pseudocode description
413+
operate!(zero, sum) # sum = 0
414+
operate!(zero, buf.c) # c = 0
415+
for (xi, yi) in zip(x, y) # for i in eachindex(input)
416+
operate_to!(buf.input, copy, xi) # input = x[i]
417+
operate!(*, buf.input, yi) # input = x[i] * y[i]
418+
operate_to!(buf.t, +, sum, buf.input) # t = sum + input
419+
if _abs_lte_abs(buf.input, sum) # if |input| < |sum|
420+
operate_to!(buf.tmp, copy, sum) # tmp = sum
421+
operate!(-, buf.tmp, buf.t) # tmp = sum - t
422+
operate!(+, buf.tmp, buf.input) # tmp = (sum - t) + input
437423
else
438-
set!(buf.inner_temp, buf.multiplication_temp)
439-
operate!(-, buf.inner_temp, buf.summation_temp)
440-
operate!(+, buf.inner_temp, sum)
424+
operate_to!(buf.tmp, copy, buf.input) # tmp = input
425+
operate!(-, buf.tmp, buf.t) # tmp = input - t
426+
operate!(+, buf.tmp, sum) # tmp = (input - t) + sum
441427
end
442-
443-
operate!(+, buf.compensation, buf.inner_temp)
444-
445-
swap!(sum, buf.summation_temp)
428+
operate!(+, buf.c, buf.tmp) # c += tmp
429+
operate_to!(sum, copy, buf.t) # sum = t
446430
end
447-
448-
operate!(+, sum, buf.compensation)
449-
431+
operate!(+, sum, buf.c) # sum += c
450432
return sum
451433
end
452434

test/bigfloat_dot.jl

Lines changed: 30 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4,71 +4,6 @@
44
# v.2.0. If a copy of the MPL was not distributed with this file, You can obtain
55
# one at http://mozilla.org/MPL/2.0/.
66

7-
backup_bigfloats(v::AbstractVector{BigFloat}) = map(MA.copy_if_mutable, v)
8-
9-
absolute_error(accurate::Real, approximate::Real) = abs(accurate - approximate)
10-
11-
function relative_error(accurate::Real, approximate::Real)
12-
return absolute_error(accurate, approximate) / abs(accurate)
13-
end
14-
15-
function dotter(x::V, y::V) where {V<:AbstractVector{<:Real}}
16-
let x = x, y = y
17-
() -> LinearAlgebra.dot(x, y)
18-
end
19-
end
20-
21-
function reference_dot(x::V, y::V) where {F<:Real,V<:AbstractVector{F}}
22-
return setprecision(dotter(x, y), F, 8 * precision(F))
23-
end
24-
25-
function dot_test_relative_error(x::V, y::V) where {V<:AbstractVector{BigFloat}}
26-
buf = MA.buffer_for(LinearAlgebra.dot, V, V)
27-
28-
input = (x, y)
29-
backup = map(backup_bigfloats, input)
30-
31-
output = BigFloat()
32-
33-
MA.buffered_operate_to!!(buf, output, LinearAlgebra.dot, input...)
34-
35-
@test input == backup
36-
37-
return relative_error(reference_dot(input...), output)
38-
end
39-
40-
subtracter(s::Real) =
41-
let s = s
42-
x -> x - s
43-
end
44-
45-
our_rand(n::Int, bias::Real) = map(subtracter(bias), rand(BigFloat, n))
46-
47-
function rand_dot_rel_err(size::Int, bias::Real)
48-
x = our_rand(size, bias)
49-
y = our_rand(size, bias)
50-
return dot_test_relative_error(x, y)
51-
end
52-
53-
function max_rand_dot_rel_err(size::Int, bias::Real, iter_cnt::Int)
54-
max_rel_err = zero(BigFloat)
55-
for i in 1:iter_cnt
56-
rel_err = rand_dot_rel_err(size, bias)
57-
<(max_rel_err, rel_err) && (max_rel_err = rel_err)
58-
end
59-
return max_rel_err
60-
end
61-
62-
function max_rand_dot_ulps(size::Int, bias::Real, iter_cnt::Int)
63-
return max_rand_dot_rel_err(size, bias, iter_cnt) / eps(BigFloat)
64-
end
65-
66-
function ulper(size::Int, bias::Real, iter_cnt::Int)
67-
let s = size, b = bias, c = iter_cnt
68-
() -> max_rand_dot_ulps(s, b, c)
69-
end
70-
end
71-
727
@testset "prec:$prec size:$size bias:$bias" for (prec, size, bias) in
738
Iterators.product(
749
# These precisions (in bits) are most probably smaller than what
@@ -78,11 +13,9 @@ end
7813
# precision (except when vector lengths are really huge with
7914
# respect to the precision).
8015
(32, 64),
81-
8216
# Compensated summation should be accurate even for very large
8317
# input vectors, so test that.
8418
(10000,),
85-
8619
# The zero "bias" signifies that the input will be entirely
8720
# nonnegative (drawn from the interval [0, 1]), while a positive
8821
# bias shifts that interval towards negative infinity. We want to
@@ -91,8 +24,36 @@ end
9124
# no guarantee on the relative error in that case.
9225
(0.0, 2^-2, 2^-2 + 2^-3 + 2^-4),
9326
)
94-
iter_cnt = 10
95-
err = setprecision(ulper(size, bias, iter_cnt), BigFloat, prec)
27+
err = setprecision(BigFloat, prec) do
28+
maximum_relative_error = mapreduce(max, 1:10) do _
29+
# Generate some random vectors for dot(x, y) input.
30+
x = rand(BigFloat, size) .- bias
31+
y = rand(BigFloat, size) .- bias
32+
# Copy x and y so that we can check we haven't mutated them after
33+
# the fact.
34+
old_x, old_y = MA.copy_if_mutable(x), MA.copy_if_mutable(y)
35+
# Compute output = dot(x, y)
36+
buf = MA.buffer_for(
37+
LinearAlgebra.dot,
38+
Vector{BigFloat},
39+
Vector{BigFloat},
40+
)
41+
output = BigFloat()
42+
MA.buffered_operate_to!!(buf, output, LinearAlgebra.dot, x, y)
43+
# Check that we haven't mutated x or y
44+
@test old_x == x
45+
@test old_y == y
46+
# Compute dot(x, y) in larger precision. This will be used to
47+
# compare with our `dot`.
48+
accurate = setprecision(BigFloat, 8 * precision(BigFloat)) do
49+
return LinearAlgebra.dot(x, y)
50+
end
51+
# Compute the relative error
52+
return abs(accurate - output) / abs(accurate)
53+
end
54+
# Return estimate for ULP
55+
return maximum_relative_error / eps(BigFloat)
56+
end
9657
@test 0 <= err < 1
9758
end
9859

0 commit comments

Comments
 (0)