Skip to content

Commit 41d50fd

Browse files
committed
exact BigFloat to IEEE FP conversion in pure Julia
There's lots of code, but most of it seems like it will be useful in general. For example, I think I'll use the changes in float.jl and rounding.jl to improve the #49749 PR. The changes in float.jl could also be used to refactor float.jl to remove many magic constants. Benchmarking script: ```julia using BenchmarkTools f(::Type{T} = BigFloat, n::Int = 2000) where {T} = rand(T, n) g!(u, v) = map!(eltype(u), u, v) @Btime g!(u, v) setup=(u = f(Float16); v = f();) @Btime g!(u, v) setup=(u = f(Float32); v = f();) @Btime g!(u, v) setup=(u = f(Float64); v = f();) ``` On master (dc06468): ``` 46.116 μs (0 allocations: 0 bytes) 38.842 μs (0 allocations: 0 bytes) 37.039 μs (0 allocations: 0 bytes) ``` With both this commit and #50674 applied: ``` 42.870 μs (0 allocations: 0 bytes) 42.950 μs (0 allocations: 0 bytes) 42.158 μs (0 allocations: 0 bytes) ``` So, with this benchmark at least, on an AMD Zen 2 laptop, conversion to `Float16` is faster, but there's a slowdown for `Float32` and `Float64`. Fixes #50642 (exact conversion to `Float16`)
1 parent dc06468 commit 41d50fd

File tree

6 files changed

+342
-28
lines changed

6 files changed

+342
-28
lines changed

base/Base.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ include("hashing.jl")
224224
include("rounding.jl")
225225
using .Rounding
226226
include("div.jl")
227+
include("rawbigints.jl")
227228
include("float.jl")
228229
include("twiceprecision.jl")
229230
include("complex.jl")

base/float.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,61 @@ i.e. the maximum integer value representable by [`exponent_bits(T)`](@ref) bits.
137137
"""
138138
function exponent_raw_max end
139139

140+
"""
141+
IEEE 754 definition of the minimum exponent.
142+
"""
143+
ieee754_exponent_min(::Type{T}) where {T<:IEEEFloat} = Int(1 - exponent_max(T))::Int
144+
145+
exponent_min(::Type{Float16}) = ieee754_exponent_min(Float16)
146+
exponent_min(::Type{Float32}) = ieee754_exponent_min(Float32)
147+
exponent_min(::Type{Float64}) = ieee754_exponent_min(Float64)
148+
149+
function ieee754_representation(
150+
::Type{F}, sign_bit::Bool, exponent_field::Integer, significand_field::Integer
151+
) where {F<:IEEEFloat}
152+
T = uinttype(F)
153+
ret::T = sign_bit
154+
ret <<= exponent_bits(F)
155+
ret |= exponent_field
156+
ret <<= significand_bits(F)
157+
ret |= significand_field
158+
end
159+
160+
# NaN or an infinity
161+
function ieee754_representation(
162+
::Type{F}, sign_bit::Bool, significand_field::Integer, ::Val{:nan}
163+
) where {F<:IEEEFloat}
164+
ieee754_representation(F, sign_bit, exponent_raw_max(F), significand_field)
165+
end
166+
167+
# NaN with default payload
168+
function ieee754_representation(
169+
::Type{F}, sign_bit::Bool, ::Val{:nan}
170+
) where {F<:IEEEFloat}
171+
ieee754_representation(F, sign_bit, one(uinttype(F)) << (significand_bits(F) - 1), Val(:nan))
172+
end
173+
174+
# Infinity
175+
function ieee754_representation(
176+
::Type{F}, sign_bit::Bool, ::Val{:inf}
177+
) where {F<:IEEEFloat}
178+
ieee754_representation(F, sign_bit, false, Val(:nan))
179+
end
180+
181+
# Subnormal or zero
182+
function ieee754_representation(
183+
::Type{F}, sign_bit::Bool, significand_field::Integer, ::Val{:subnormal}
184+
) where {F<:IEEEFloat}
185+
ieee754_representation(F, sign_bit, false, significand_field)
186+
end
187+
188+
# Zero
189+
function ieee754_representation(
190+
::Type{F}, sign_bit::Bool, ::Val{:zero}
191+
) where {F<:IEEEFloat}
192+
ieee754_representation(F, sign_bit, false, Val(:subnormal))
193+
end
194+
140195
"""
141196
uabs(x::Integer)
142197

base/mpfr.jl

Lines changed: 71 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@ import
1717
cbrt, typemax, typemin, unsafe_trunc, floatmin, floatmax, rounding,
1818
setrounding, maxintfloat, widen, significand, frexp, tryparse, iszero,
1919
isone, big, _string_n, decompose, minmax,
20-
sinpi, cospi, sincospi, tanpi, sind, cosd, tand, asind, acosd, atand
20+
sinpi, cospi, sincospi, tanpi, sind, cosd, tand, asind, acosd, atand,
21+
uinttype, exponent_max, exponent_min, ieee754_representation,
22+
RawBigIntRoundingIncrementHelper, truncated, RawBigInt
2123

2224

2325
using .Base.Libc
24-
import ..Rounding: rounding_raw, setrounding_raw
26+
import ..Rounding:
27+
rounding_raw, setrounding_raw, rounds_to_nearest, rounds_away_from_zero,
28+
tie_breaker_is_to_even, correct_rounding_requires_increment
2529

2630
import ..GMP: ClongMax, CulongMax, CdoubleMax, Limb, libgmp
2731

@@ -89,6 +93,21 @@ function convert(::Type{RoundingMode}, r::MPFRRoundingMode)
8993
end
9094
end
9195

96+
rounds_to_nearest(m::MPFRRoundingMode) = m == MPFRRoundNearest
97+
function rounds_away_from_zero(m::MPFRRoundingMode, sign_bit::Bool)
98+
if m == MPFRRoundToZero
99+
false
100+
elseif m == MPFRRoundUp
101+
!sign_bit
102+
elseif m == MPFRRoundDown
103+
sign_bit
104+
else
105+
# Assuming `m == MPFRRoundFromZero`
106+
true
107+
end
108+
end
109+
tie_breaker_is_to_even(::MPFRRoundingMode) = true
110+
92111
const ROUNDING_MODE = Ref{MPFRRoundingMode}(MPFRRoundNearest)
93112
const DEFAULT_PRECISION = Ref{Clong}(256)
94113

@@ -130,6 +149,9 @@ mutable struct BigFloat <: AbstractFloat
130149
end
131150
end
132151

152+
# The rounding mode here shouldn't matter.
153+
significand_limb_count(x::BigFloat) = div(sizeof(x._d), sizeof(Limb), RoundToZero)
154+
133155
rounding_raw(::Type{BigFloat}) = ROUNDING_MODE[]
134156
setrounding_raw(::Type{BigFloat}, r::MPFRRoundingMode) = ROUNDING_MODE[]=r
135157

@@ -380,35 +402,56 @@ function (::Type{T})(x::BigFloat) where T<:Integer
380402
trunc(T,x)
381403
end
382404

383-
## BigFloat -> AbstractFloat
384-
_cpynansgn(x::AbstractFloat, y::BigFloat) = isnan(x) && signbit(x) != signbit(y) ? -x : x
385-
386-
Float64(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) =
387-
_cpynansgn(ccall((:mpfr_get_d,libmpfr), Float64, (Ref{BigFloat}, MPFRRoundingMode), x, r), x)
388-
Float64(x::BigFloat, r::RoundingMode) = Float64(x, convert(MPFRRoundingMode, r))
389-
390-
Float32(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) =
391-
_cpynansgn(ccall((:mpfr_get_flt,libmpfr), Float32, (Ref{BigFloat}, MPFRRoundingMode), x, r), x)
392-
Float32(x::BigFloat, r::RoundingMode) = Float32(x, convert(MPFRRoundingMode, r))
393-
394-
function Float16(x::BigFloat) :: Float16
395-
res = Float32(x)
396-
resi = reinterpret(UInt32, res)
397-
if (resi&0x7fffffff) < 0x38800000 # if Float16(res) is subnormal
398-
#shift so that the mantissa lines up where it would for normal Float16
399-
shift = 113-((resi & 0x7f800000)>>23)
400-
if shift<23
401-
resi |= 0x0080_0000 # set implicit bit
402-
resi >>= shift
405+
function to_ieee754(::Type{T}, x::BigFloat, rm) where {T<:AbstractFloat}
406+
sb = signbit(x)
407+
is_zero = iszero(x)
408+
is_inf = isinf(x)
409+
is_nan = isnan(x)
410+
is_regular = !is_zero & !is_inf & !is_nan
411+
ieee_exp = Int(x.exp) - 1
412+
ieee_precision = precision(T)
413+
ieee_exp_max = exponent_max(T)
414+
ieee_exp_min = exponent_min(T)
415+
ieee_exp_min_subnormal = ieee_exp_min - ieee_precision + 1
416+
exp_diff = ieee_exp - ieee_exp_min
417+
is_normal = 0 exp_diff
418+
419+
# Note: sufficient but not necessary, depending on the rounding mode
420+
rounds_to_inf = is_inf | (is_regular & (ieee_exp_max < ieee_exp))
421+
rounds_to_zero = is_zero | (is_regular & (ieee_exp < ieee_exp_min_subnormal))
422+
rounds_to_regular = !is_nan & !rounds_to_inf & !rounds_to_zero
423+
424+
U = uinttype(T)
425+
426+
ret_u = if rounds_to_regular
427+
v = RawBigInt(x.d, significand_limb_count(x))
428+
len = max(ieee_precision + min(exp_diff, 0), 0)::Int
429+
signif = truncated(U, v, len)
430+
c = 8*sizeof(U) - len + 1
431+
is_normal && (signif = (signif << c) >> c) # implicit bit convention
432+
rh = RawBigIntRoundingIncrementHelper(v, len)
433+
incr = correct_rounding_requires_increment(rh, rm, sb)
434+
exp_field = max(exp_diff, 0) + is_normal
435+
ieee754_representation(T, sb, exp_field, signif) + incr
436+
else
437+
if rounds_to_zero
438+
ieee754_representation(T, sb, Val(:zero))
439+
elseif rounds_to_inf
440+
ieee754_representation(T, sb, Val(:inf))
441+
else
442+
ieee754_representation(T, sb, Val(:nan))
403443
end
404-
end
405-
if (resi & 0x1fff == 0x1000) # if we are halfway between 2 Float16 values
406-
# adjust the value by 1 ULP in the direction that will make Float16(res) give the right answer
407-
res = nextfloat(res, cmp(x, res))
408-
end
409-
return res
444+
end::U
445+
reinterpret(T, ret_u)
410446
end
411447

448+
Float16(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) = to_ieee754(Float16, x, r)
449+
Float32(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) = to_ieee754(Float32, x, r)
450+
Float64(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) = to_ieee754(Float64, x, r)
451+
Float16(x::BigFloat, r::RoundingMode) = to_ieee754(Float16, x, r)
452+
Float32(x::BigFloat, r::RoundingMode) = to_ieee754(Float32, x, r)
453+
Float64(x::BigFloat, r::RoundingMode) = to_ieee754(Float64, x, r)
454+
412455
promote_rule(::Type{BigFloat}, ::Type{<:Real}) = BigFloat
413456
promote_rule(::Type{BigInt}, ::Type{<:AbstractFloat}) = BigFloat
414457
promote_rule(::Type{BigFloat}, ::Type{<:AbstractFloat}) = BigFloat

base/rawbigints.jl

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
"""
4+
Segment of raw words of bits interpreted as a big integer. Less
5+
significant words come first. Each word is in machine-native bit-order.
6+
"""
7+
struct RawBigInt{T<:Unsigned}
8+
d::Ptr{T}
9+
word_count::Int
10+
11+
function RawBigInt{T}(d::Ptr{T}, word_count::Int) where {T<:Unsigned}
12+
new{T}(d, word_count)
13+
end
14+
end
15+
16+
RawBigInt(d::Ptr{T}, word_count::Int) where {T<:Unsigned} = RawBigInt{T}(d, word_count)
17+
elem_count(x::RawBigInt, ::Val{:words}) = x.word_count
18+
elem_count(x::Unsigned, ::Val{:bits}) = sizeof(x) * 8
19+
word_length(::RawBigInt{T}) where {T} = elem_count(zero(T), Val(:bits))
20+
elem_count(x::RawBigInt{T}, ::Val{:bits}) where {T} = word_length(x) * elem_count(x, Val(:words))
21+
reversed_index(n::Int, i::Int) = n - i - 1
22+
reversed_index(x, i::Int, v::Val) = reversed_index(elem_count(x, v), i)::Int
23+
split_bit_index(x::RawBigInt, i::Int) = divrem(i, word_length(x), RoundToZero)
24+
25+
"""
26+
`i` is the zero-based index of the wanted word in `x`, starting from
27+
the less significant words.
28+
"""
29+
function get_elem(x::RawBigInt, i::Int, ::Val{:words}, ::Val{:ascending})
30+
unsafe_load(x.d, i + 1)
31+
end
32+
33+
function get_elem(x, i::Int, v::Val, ::Val{:descending})
34+
j = reversed_index(x, i, v)
35+
get_elem(x, j, v, Val(:ascending))
36+
end
37+
38+
word_is_nonzero(x::RawBigInt, i::Int, v::Val) = !iszero(get_elem(x, i, Val(:words), v))
39+
40+
word_is_nonzero(x::RawBigInt, v::Val) = let x = x
41+
i -> word_is_nonzero(x, i, v)
42+
end
43+
44+
"""
45+
Returns a `Bool` indicating whether the `len` least significant words
46+
of `x` are nonzero.
47+
"""
48+
function tail_is_nonzero(x::RawBigInt, len::Int, ::Val{:words})
49+
any(word_is_nonzero(x, Val(:ascending)), 0:(len - 1))
50+
end
51+
52+
"""
53+
Returns a `Bool` indicating whether the `len` least significant bits of
54+
the `i`-th (zero-based index) word of `x` are nonzero.
55+
"""
56+
function tail_is_nonzero(x::RawBigInt, len::Int, i::Int, ::Val{:word})
57+
!iszero(len) &&
58+
!iszero(get_elem(x, i, Val(:words), Val(:ascending)) << (word_length(x) - len))
59+
end
60+
61+
"""
62+
Returns a `Bool` indicating whether the `len` least significant bits of
63+
`x` are nonzero.
64+
"""
65+
function tail_is_nonzero(x::RawBigInt, len::Int, ::Val{:bits})
66+
if 0 < len
67+
word_count, bit_count_in_word = split_bit_index(x, len)
68+
tail_is_nonzero(x, bit_count_in_word, word_count, Val(:word)) ||
69+
tail_is_nonzero(x, word_count, Val(:words))
70+
else
71+
false
72+
end::Bool
73+
end
74+
75+
"""
76+
Returns a `Bool` that is the `i`-th (zero-based index) bit of `x`.
77+
"""
78+
function get_elem(x::Unsigned, i::Int, ::Val{:bits}, ::Val{:ascending})
79+
(x >>> i) % Bool
80+
end
81+
82+
"""
83+
Returns a `Bool` that is the `i`-th (zero-based index) bit of `x`.
84+
"""
85+
function get_elem(x::RawBigInt, i::Int, ::Val{:bits}, v::Val{:ascending})
86+
vb = Val(:bits)
87+
if 0 i < elem_count(x, vb)
88+
word_index, bit_index_in_word = split_bit_index(x, i)
89+
word = get_elem(x, word_index, Val(:words), v)
90+
get_elem(word, bit_index_in_word, vb, v)
91+
else
92+
false
93+
end::Bool
94+
end
95+
96+
"""
97+
Returns an integer of type `R`, consisting of the `len` most
98+
significant bits of `x`.
99+
"""
100+
function truncated(::Type{R}, x::RawBigInt, len::Int) where {R<:Integer}
101+
ret = zero(R)
102+
if 0 < len
103+
word_count, bit_count_in_word = split_bit_index(x, len)
104+
k = word_length(x)
105+
vals = (Val(:words), Val(:descending))
106+
107+
for w 0:(word_count - 1)
108+
ret <<= k
109+
word = get_elem(x, w, vals...)
110+
ret |= R(word)
111+
end
112+
113+
if !iszero(bit_count_in_word)
114+
ret <<= bit_count_in_word
115+
wrd = get_elem(x, word_count, vals...)
116+
ret |= R(wrd >>> (k - bit_count_in_word))
117+
end
118+
end
119+
ret::R
120+
end
121+
122+
struct RawBigIntRoundingIncrementHelper{T<:Unsigned}
123+
n::RawBigInt{T}
124+
trunc_len::Int
125+
126+
final_bit::Bool
127+
round_bit::Bool
128+
129+
function RawBigIntRoundingIncrementHelper{T}(n::RawBigInt{T}, len::Int) where {T<:Unsigned}
130+
vals = (Val(:bits), Val(:descending))
131+
f = get_elem(n, len - 1, vals...)
132+
r = get_elem(n, len , vals...)
133+
new{T}(n, len, f, r)
134+
end
135+
end
136+
137+
function RawBigIntRoundingIncrementHelper(n::RawBigInt{T}, len::Int) where {T<:Unsigned}
138+
RawBigIntRoundingIncrementHelper{T}(n, len)
139+
end
140+
141+
(h::RawBigIntRoundingIncrementHelper)(::Rounding.FinalBit) = h.final_bit
142+
143+
(h::RawBigIntRoundingIncrementHelper)(::Rounding.RoundBit) = h.round_bit
144+
145+
function (h::RawBigIntRoundingIncrementHelper)(::Rounding.StickyBit)
146+
v = Val(:bits)
147+
n = h.n
148+
tail_is_nonzero(n, elem_count(n, v) - h.trunc_len - 1, v)
149+
end

0 commit comments

Comments
 (0)