|
9 | 9 |
|
10 | 10 | mutability(::Type{BigFloat}) = IsMutable() |
11 | 11 |
|
12 | | -# Copied from `deepcopy_internal` implementation in Julia: |
13 | | -# https://github.com/JuliaLang/julia/blob/7d41d1eb610cad490cbaece8887f9bbd2a775021/base/mpfr.jl#L1041-L1050 |
14 | | -function mutable_copy(x::BigFloat) |
15 | | - d = x._d |
16 | | - d′ = GC.@preserve d unsafe_string(pointer(d), sizeof(d)) # creates a definitely-new String |
17 | | - return Base.MPFR._BigFloat(x.prec, x.sign, x.exp, d′) |
| 12 | +# These methods are copied from `deepcopy_internal` in `base/mpfr.jl`. We don't |
| 13 | +# use `mutable_copy(x) = deepcopy(x)` because this creates an empty `IdDict()` |
| 14 | +# which costs some extra allocations. We don't need the IdDict case because we |
| 15 | +# never call `mutable_copy` recursively. |
| 16 | +@static if VERSION >= v"1.12.0-DEV.1343" |
| 17 | + mutable_copy(x::BigFloat) = Base.MPFR._BigFloat(copy(getfield(x, :d))) |
| 18 | +else |
| 19 | + function mutable_copy(x::BigFloat) |
| 20 | + d = x._d |
| 21 | + GC.@preserve d begin |
| 22 | + d′ = unsafe_string(pointer(d), sizeof(d)) |
| 23 | + return Base.MPFR._BigFloat(x.prec, x.sign, x.exp, d′) |
| 24 | + end |
| 25 | + end |
18 | 26 | end |
19 | 27 |
|
20 | 28 | const _MPFRRoundingMode = Base.MPFR.MPFRRoundingMode |
@@ -297,12 +305,12 @@ function operate_to!( |
297 | 305 | end |
298 | 306 |
|
299 | 307 | struct DotBuffer{F<:Real} |
300 | | - compensation::F |
301 | | - summation_temp::F |
302 | | - multiplication_temp::F |
303 | | - inner_temp::F |
| 308 | + c::F |
| 309 | + t::F |
| 310 | + input::F |
| 311 | + tmp::F |
304 | 312 |
|
305 | | - DotBuffer{F}() where {F<:Real} = new{F}(ntuple(i -> F(), Val{4}())...) |
| 313 | + DotBuffer{F}() where {F<:Real} = new{F}(zero(F), zero(F), zero(F), zero(F)) |
306 | 314 | end |
307 | 315 |
|
308 | 316 | function buffer_for( |
@@ -366,87 +374,61 @@ end |
366 | 374 | # |
367 | 375 | # function KahanBabushkaNeumaierSum(input) |
368 | 376 | # sum = 0.0 |
369 | | -# |
370 | 377 | # # A running compensation for lost low-order bits. |
371 | 378 | # c = 0.0 |
372 | | -# |
373 | 379 | # for i ∈ eachindex(input) |
374 | 380 | # t = sum + input[i] |
375 | | -# |
376 | 381 | # if abs(input[i]) ≤ abs(sum) |
377 | | -# c += (sum - t) + input[i] |
| 382 | +# tmp = (sum - t) + input[i] |
378 | 383 | # else |
379 | | -# c += (input[i] - t) + sum |
| 384 | +# tmp = (input[i] - t) + sum |
380 | 385 | # end |
381 | | -# |
| 386 | +# c += tmp |
382 | 387 | # sum = t |
383 | 388 | # end |
384 | | -# |
385 | | -# # The result, with the correction only applied once in the very |
386 | | -# # end. |
| 389 | +# # The result, with the correction only applied once in the very end. |
387 | 390 | # sum + c |
388 | 391 | # end |
389 | | -function buffered_operate_to!( |
390 | | - buf::DotBuffer{F}, |
391 | | - sum::F, |
392 | | - ::typeof(LinearAlgebra.dot), |
393 | | - x::AbstractVector{F}, |
394 | | - y::AbstractVector{F}, |
395 | | -) where {F<:BigFloat} |
396 | | - set! = (o, i) -> operate_to!(o, copy, i) |
397 | 392 |
|
398 | | - local swap! = function (x::BigFloat, y::BigFloat) |
399 | | - ccall((:mpfr_swap, :libmpfr), Cvoid, (Ref{BigFloat}, Ref{BigFloat}), x, y) |
400 | | - return nothing |
| 393 | +# Returns abs(x) <= abs(y) without allocating. |
| 394 | +function _abs_lte_abs(x::BigFloat, y::BigFloat) |
| 395 | + x_is_neg, y_is_neg = signbit(x), signbit(y) |
| 396 | + if x_is_neg != y_is_neg |
| 397 | + operate!(-, x) |
401 | 398 | end |
402 | | - |
403 | | - # Returns abs(x) <= abs(y) without allocating. |
404 | | - local abs_lte_abs = function (x::F, y::F) |
405 | | - local x_is_neg = signbit(x) |
406 | | - local y_is_neg = signbit(y) |
407 | | - |
408 | | - local x_neg = x_is_neg != y_is_neg |
409 | | - |
410 | | - x_neg && operate!(-, x) |
411 | | - |
412 | | - local ret = if y_is_neg |
413 | | - y <= x |
414 | | - else |
415 | | - x <= y |
416 | | - end |
417 | | - |
418 | | - x_neg && operate!(-, x) |
419 | | - |
420 | | - return ret |
| 399 | + ret = y_is_neg ? y <= x : x <= y |
| 400 | + if x_is_neg != y_is_neg |
| 401 | + operate!(-, x) |
421 | 402 | end |
| 403 | + return ret |
| 404 | +end |
422 | 405 |
|
423 | | - operate!(zero, sum) |
424 | | - operate!(zero, buf.compensation) |
425 | | - |
426 | | - for i in 0:(length(x)-1) |
427 | | - set!(buf.multiplication_temp, x[begin+i]) |
428 | | - operate!(*, buf.multiplication_temp, y[begin+i]) |
429 | | - |
430 | | - operate!(zero, buf.summation_temp) |
431 | | - operate_to!(buf.summation_temp, +, buf.multiplication_temp, sum) |
432 | | - |
433 | | - if abs_lte_abs(buf.multiplication_temp, sum) |
434 | | - set!(buf.inner_temp, sum) |
435 | | - operate!(-, buf.inner_temp, buf.summation_temp) |
436 | | - operate!(+, buf.inner_temp, buf.multiplication_temp) |
| 406 | +function buffered_operate_to!( |
| 407 | + buf::DotBuffer{BigFloat}, |
| 408 | + sum::BigFloat, |
| 409 | + ::typeof(LinearAlgebra.dot), |
| 410 | + x::AbstractVector{BigFloat}, |
| 411 | + y::AbstractVector{BigFloat}, |
| 412 | +) # See pseudocode description |
| 413 | + operate!(zero, sum) # sum = 0 |
| 414 | + operate!(zero, buf.c) # c = 0 |
| 415 | + for (xi, yi) in zip(x, y) # for i in eachindex(input) |
| 416 | + operate_to!(buf.input, copy, xi) # input = x[i] |
| 417 | + operate!(*, buf.input, yi) # input = x[i] * y[i] |
| 418 | + operate_to!(buf.t, +, sum, buf.input) # t = sum + input |
| 419 | + if _abs_lte_abs(buf.input, sum) # if |input| <= |sum| |
| 420 | + operate_to!(buf.tmp, copy, sum) # tmp = sum |
| 421 | + operate!(-, buf.tmp, buf.t) # tmp = sum - t |
| 422 | + operate!(+, buf.tmp, buf.input) # tmp = (sum - t) + input |
437 | 423 | else |
438 | | - set!(buf.inner_temp, buf.multiplication_temp) |
439 | | - operate!(-, buf.inner_temp, buf.summation_temp) |
440 | | - operate!(+, buf.inner_temp, sum) |
| 424 | + operate_to!(buf.tmp, copy, buf.input) # tmp = input |
| 425 | + operate!(-, buf.tmp, buf.t) # tmp = input - t |
| 426 | + operate!(+, buf.tmp, sum) # tmp = (input - t) + sum |
441 | 427 | end |
442 | | - |
443 | | - operate!(+, buf.compensation, buf.inner_temp) |
444 | | - |
445 | | - swap!(sum, buf.summation_temp) |
| 428 | + operate!(+, buf.c, buf.tmp) # c += tmp |
| 429 | + operate_to!(sum, copy, buf.t) # sum = t |
446 | 430 | end |
447 | | - |
448 | | - operate!(+, sum, buf.compensation) |
449 | | - |
| 431 | + operate!(+, sum, buf.c) # sum += c |
450 | 432 | return sum |
451 | 433 | end |
452 | 434 |
|
|
0 commit comments