From 8881ef9dc39a68bec993629c2ffb105ae46f437b Mon Sep 17 00:00:00 2001
From: Trevor Gross
Date: Thu, 23 Jan 2025 08:28:58 +0000
Subject: [PATCH] Start converting `fma` to a generic function

This is the first step toward making `fma` usable for `f128`, and
possibly for `f32` on platforms where growing to `f64` is not fast.
This does not yet work for anything other than `f64`.
---
 etc/function-definitions.json    |   6 +-
 src/math/fma.rs                  | 192 +-------------------------
 src/math/generic/fma.rs          | 227 +++++++++++++++++++++++++++++++
 src/math/generic/mod.rs          |   2 +
 src/math/support/float_traits.rs |   4 +-
 src/math/support/int_traits.rs   |  39 ++++++
 6 files changed, 278 insertions(+), 192 deletions(-)
 create mode 100644 src/math/generic/fma.rs

diff --git a/etc/function-definitions.json b/etc/function-definitions.json
index a1d3adf59..243862075 100644
--- a/etc/function-definitions.json
+++ b/etc/function-definitions.json
@@ -344,13 +344,15 @@
     },
     "fma": {
         "sources": [
-            "src/math/fma.rs"
+            "src/math/fma.rs",
+            "src/math/generic/fma.rs"
         ],
         "type": "f64"
     },
     "fmaf": {
         "sources": [
-            "src/math/fmaf.rs"
+            "src/math/fmaf.rs",
+            "src/math/generic/fma.rs"
         ],
         "type": "f32"
     },
diff --git a/src/math/fma.rs b/src/math/fma.rs
index 826143d5a..69cc3eb67 100644
--- a/src/math/fma.rs
+++ b/src/math/fma.rs
@@ -1,195 +1,9 @@
-use core::{f32, f64};
-
-use super::scalbn;
-
-const ZEROINFNAN: i32 = 0x7ff - 0x3ff - 52 - 1;
-
-struct Num {
-    m: u64,
-    e: i32,
-    sign: i32,
-}
-
-fn normalize(x: f64) -> Num {
-    let x1p63: f64 = f64::from_bits(0x43e0000000000000); // 0x1p63 === 2 ^ 63
-
-    let mut ix: u64 = x.to_bits();
-    let mut e: i32 = (ix >> 52) as i32;
-    let sign: i32 = e & 0x800;
-    e &= 0x7ff;
-    if e == 0 {
-        ix = (x * x1p63).to_bits();
-        e = (ix >> 52) as i32 & 0x7ff;
-        e = if e != 0 { e - 63 } else { 0x800 };
-    }
-    ix &= (1 << 52) - 1;
-    ix |= 1 << 52;
-    ix <<= 1;
-    e -= 0x3ff + 52 + 1;
-    Num { m: ix, e, sign }
-}
-
-#[inline]
-fn mul(x: u64, y: u64) -> (u64, u64) {
-    let t = (x as u128).wrapping_mul(y as u128);
-    ((t >> 64) as u64, t as u64)
-}
-
-/// Floating multiply add (f64)
+/// Fused multiply-add (f64)
 ///
-/// Computes `(x*y)+z`, rounded as one ternary operation:
-/// Computes the value (as if) to infinite precision and rounds once to the result format,
-/// according to the rounding mode characterized by the value of FLT_ROUNDS.
+/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. computed as if to infinite precision, then rounded once).
 #[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
 pub fn fma(x: f64, y: f64, z: f64) -> f64 {
-    let x1p63: f64 = f64::from_bits(0x43e0000000000000); // 0x1p63 === 2 ^ 63
-    let x0_ffffff8p_63 = f64::from_bits(0x3bfffffff0000000); // 0x0.ffffff8p-63
-
-    /* normalize so top 10bits and last bit are 0 */
-    let nx = normalize(x);
-    let ny = normalize(y);
-    let nz = normalize(z);
-
-    if nx.e >= ZEROINFNAN || ny.e >= ZEROINFNAN {
-        return x * y + z;
-    }
-    if nz.e >= ZEROINFNAN {
-        if nz.e > ZEROINFNAN {
-            /* z==0 */
-            return x * y + z;
-        }
-        return z;
-    }
-
-    /* mul: r = x*y */
-    let zhi: u64;
-    let zlo: u64;
-    let (mut rhi, mut rlo) = mul(nx.m, ny.m);
-    /* either top 20 or 21 bits of rhi and last 2 bits of rlo are 0 */
-
-    /* align exponents */
-    let mut e: i32 = nx.e + ny.e;
-    let mut d: i32 = nz.e - e;
-    /* shift bits z<<=kz, r>>=kr, so kz+kr == d, set e = e+kr (== ez-kz) */
-    if d > 0 {
-        if d < 64 {
-            zlo = nz.m << d;
-            zhi = nz.m >> (64 - d);
-        } else {
-            zlo = 0;
-            zhi = nz.m;
-            e = nz.e - 64;
-            d -= 64;
-            if d == 0 {
-            } else if d < 64 {
-                rlo = (rhi << (64 - d)) | (rlo >> d) | ((rlo << (64 - d)) != 0) as u64;
-                rhi = rhi >> d;
-            } else {
-                rlo = 1;
-                rhi = 0;
-            }
-        }
-    } else {
-        zhi = 0;
-        d = -d;
-        if d == 0 {
-            zlo = nz.m;
-        } else if d < 64 {
-            zlo = (nz.m >> d) | ((nz.m << (64 - d)) != 0) as u64;
-        } else {
-            zlo = 1;
-        }
-    }
-
-    /* add */
-    let mut sign: i32 = nx.sign ^ ny.sign;
-    let samesign: bool = (sign ^ nz.sign) == 0;
-    let mut nonzero: i32 = 1;
-    if samesign {
-        /* r += z */
-        rlo = rlo.wrapping_add(zlo);
-        rhi += zhi + (rlo < zlo) as u64;
-    } else {
-        /* r -= z */
-        let (res, borrow) = rlo.overflowing_sub(zlo);
-        rlo = res;
-        rhi = rhi.wrapping_sub(zhi.wrapping_add(borrow as u64));
-        if (rhi >> 63) != 0 {
-            rlo = (rlo as i64).wrapping_neg() as u64;
-            rhi = (rhi as i64).wrapping_neg() as u64 - (rlo != 0) as u64;
-            sign = (sign == 0) as i32;
-        }
-        nonzero = (rhi != 0) as i32;
-    }
-
-    /* set rhi to top 63bit of the result (last bit is sticky) */
-    if nonzero != 0 {
-        e += 64;
-        d = rhi.leading_zeros() as i32 - 1;
-        /* note: d > 0 */
-        rhi = (rhi << d) | (rlo >> (64 - d)) | ((rlo << d) != 0) as u64;
-    } else if rlo != 0 {
-        d = rlo.leading_zeros() as i32 - 1;
-        if d < 0 {
-            rhi = (rlo >> 1) | (rlo & 1);
-        } else {
-            rhi = rlo << d;
-        }
-    } else {
-        /* exact +-0 */
-        return x * y + z;
-    }
-    e -= d;
-
-    /* convert to double */
-    let mut i: i64 = rhi as i64; /* i is in [1<<62,(1<<63)-1] */
-    if sign != 0 {
-        i = -i;
-    }
-    let mut r: f64 = i as f64; /* |r| is in [0x1p62,0x1p63] */
-
-    if e < -1022 - 62 {
-        /* result is subnormal before rounding */
-        if e == -1022 - 63 {
-            let mut c: f64 = x1p63;
-            if sign != 0 {
-                c = -c;
-            }
-            if r == c {
-                /* min normal after rounding, underflow depends
-                   on arch behaviour which can be imitated by
-                   a double to float conversion */
-                let fltmin: f32 = (x0_ffffff8p_63 * f32::MIN_POSITIVE as f64 * r) as f32;
-                return f64::MIN_POSITIVE / f32::MIN_POSITIVE as f64 * fltmin as f64;
-            }
-            /* one bit is lost when scaled, add another top bit to
-               only round once at conversion if it is inexact */
-            if (rhi << 53) != 0 {
-                i = ((rhi >> 1) | (rhi & 1) | (1 << 62)) as i64;
-                if sign != 0 {
-                    i = -i;
-                }
-                r = i as f64;
-                r = 2. * r - c; /* remove top bit */
-
-                /* raise underflow portably, such that it
-                   cannot be optimized away */
-                {
-                    let tiny: f64 = f64::MIN_POSITIVE / f32::MIN_POSITIVE as f64 * r;
-                    r += (tiny * tiny) * (r - r);
-                }
-            }
-        } else {
-            /* only round once when scaled */
-            d = 10;
-            i = (((rhi >> d) | ((rhi << (64 - d)) != 0) as u64) << d) as i64;
-            if sign != 0 {
-                i = -i;
-            }
-            r = i as f64;
-        }
-    }
-    scalbn(r, e)
+    return super::generic::fma(x, y, z);
 }
 
 #[cfg(test)]
diff --git a/src/math/generic/fma.rs b/src/math/generic/fma.rs
new file mode 100644
index 000000000..3d5459f1a
--- /dev/null
+++ b/src/math/generic/fma.rs
@@ -0,0 +1,227 @@
+use core::{f32, f64};
+
+use super::super::support::{DInt, HInt, IntTy};
+use super::super::{CastFrom, CastInto, Float, Int, MinInt};
+
+const ZEROINFNAN: i32 = 0x7ff - 0x3ff - 52 - 1;
+
+/// Fused multiply-add that works when there is no larger float size available. Currently this
+/// is still specialized only for `f64`. Computes `(x * y) + z`.
+#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
+pub fn fma<F>(x: F, y: F, z: F) -> F
+where
+    F: Float + FmaHelper,
+    F: CastFrom<F::SignedInt>,
+    F: CastFrom<i8>,
+    F::Int: HInt,
+    u32: CastInto<F::Int>,
+{
+    let one = IntTy::<F>::ONE;
+    let zero = IntTy::<F>::ZERO;
+    let magic = F::from_parts(false, F::BITS - 1 + F::EXP_BIAS, zero);
+
+    /* normalize so top 10 bits and last bit are 0 */
+    let nx = Norm::from_float(x);
+    let ny = Norm::from_float(y);
+    let nz = Norm::from_float(z);
+
+    if nx.e >= ZEROINFNAN || ny.e >= ZEROINFNAN {
+        return x * y + z;
+    }
+    if nz.e >= ZEROINFNAN {
+        if nz.e > ZEROINFNAN {
+            /* z==0 */
+            return x * y + z;
+        }
+        return z;
+    }
+
+    /* mul: r = x*y */
+    let zhi: F::Int;
+    let zlo: F::Int;
+    let (mut rlo, mut rhi) = nx.m.widen_mul(ny.m).lo_hi();
+
+    /* either top 20 or 21 bits of rhi and last 2 bits of rlo are 0 */
+
+    /* align exponents */
+    let mut e: i32 = nx.e + ny.e;
+    let mut d: i32 = nz.e - e;
+    let sbits = F::BITS as i32;
+
+    /* shift bits z<<=kz, r>>=kr, so kz+kr == d, set e = e+kr (== ez-kz) */
+    if d > 0 {
+        if d < sbits {
+            zlo = nz.m << d;
+            zhi = nz.m >> (sbits - d);
+        } else {
+            zlo = zero;
+            zhi = nz.m;
+            e = nz.e - sbits;
+            d -= sbits;
+            if d == 0 {
+            } else if d < sbits {
+                rlo = (rhi << (sbits - d))
+                    | (rlo >> d)
+                    | IntTy::<F>::from((rlo << (sbits - d)) != zero);
+                rhi = rhi >> d;
+            } else {
+                rlo = one;
+                rhi = zero;
+            }
+        }
+    } else {
+        zhi = zero;
+        d = -d;
+        if d == 0 {
+            zlo = nz.m;
+        } else if d < sbits {
+            zlo = (nz.m >> d) | IntTy::<F>::from((nz.m << (sbits - d)) != zero);
+        } else {
+            zlo = one;
+        }
+    }
+
+    /* add */
+    let mut neg = nx.neg ^ ny.neg;
+    let samesign: bool = !neg ^ nz.neg;
+    let mut nonzero: i32 = 1;
+    if samesign {
+        /* r += z */
+        rlo = rlo.wrapping_add(zlo);
+        rhi += zhi + IntTy::<F>::from(rlo < zlo);
+    } else {
+        /* r -= z */
+        let (res, borrow) = rlo.overflowing_sub(zlo);
+        rlo = res;
+        rhi = rhi.wrapping_sub(zhi.wrapping_add(IntTy::<F>::from(borrow)));
+        if (rhi >> (F::BITS - 1)) != zero {
+            rlo = rlo.signed().wrapping_neg().unsigned();
+            rhi = rhi.signed().wrapping_neg().unsigned() - IntTy::<F>::from(rlo != zero);
+            neg = !neg;
+        }
+        nonzero = (rhi != zero) as i32;
+    }
+
+    /* set rhi to top 63 bits of the result (last bit is sticky) */
+    if nonzero != 0 {
+        e += sbits;
+        d = rhi.leading_zeros() as i32 - 1;
+        /* note: d > 0 */
+        rhi = (rhi << d) | (rlo >> (sbits - d)) | IntTy::<F>::from((rlo << d) != zero);
+    } else if rlo != zero {
+        d = rlo.leading_zeros() as i32 - 1;
+        if d < 0 {
+            rhi = (rlo >> 1) | (rlo & one);
+        } else {
+            rhi = rlo << d;
+        }
+    } else {
+        /* exact +-0 */
+        return x * y + z;
+    }
+    e -= d;
+
+    /* convert to double */
+    let mut i: F::SignedInt = rhi.signed(); /* i is in [1<<62,(1<<63)-1] */
+    if neg {
+        i = -i;
+    }
+
+    let mut r: F = F::cast_from_lossy(i); /* |r| is in [0x1p62,0x1p63] */
+
+    if e < -(F::EXP_BIAS as i32 - 1) - (sbits - 2) {
+        /* result is subnormal before rounding */
+        if e == -(F::EXP_BIAS as i32 - 1) - (sbits - 1) {
+            let mut c: F = magic;
+            if neg {
+                c = -c;
+            }
+            if r == c {
+                /* min normal after rounding, underflow depends
+                 * on arch behaviour which can be imitated by
+                 * a double to float conversion */
+                return r.raise_underflow();
+            }
+            /* one bit is lost when scaled, add another top bit to
+             * only round once at conversion if it is inexact */
+            if (rhi << F::SIG_BITS) != zero {
+                let iu: F::Int = (rhi >> 1) | (rhi & one) | (one << 62);
+                i = iu.signed();
+                if neg {
+                    i = -i;
+                }
+                r = F::cast_from_lossy(i);
+                r = F::cast_from(2i8) * r - c; /* remove top bit */
+
+                /* raise underflow portably, such that it
+                 * cannot be optimized away */
+                r += r.raise_underflow2();
+            }
+        } else {
+            /* only round once when scaled */
+            d = 10;
+            i = (((rhi >> d) | IntTy::<F>::from(rhi << (F::BITS as i32 - d) != zero)) << d)
+                .signed();
+            if neg {
+                i = -i;
+            }
+            r = F::cast_from(i);
+        }
+    }
+
+    super::scalbn(r, e)
+}
+
+/// Representation of `F` with subnormals normalized.
+struct Norm<F: Float> {
+    /// Normalized significand with one guard bit.
+    m: F::Int,
+    /// Unbiased exponent, normalized.
+    e: i32,
+    neg: bool,
+}
+
+impl<F: Float> Norm<F> {
+    fn from_float(x: F) -> Self {
+        let mut ix = x.to_bits();
+        let mut e = x.exp() as i32;
+        let neg = x.is_sign_negative();
+        if e == 0 {
+            // Normalize subnormals by multiplication
+            let magic = F::from_parts(false, F::BITS - 1 + F::EXP_BIAS, F::Int::ZERO);
+            let scaled = x * magic;
+            ix = scaled.to_bits();
+            e = scaled.exp() as i32;
+            e = if e != 0 { e - (F::BITS as i32 - 1) } else { 0x800 };
+        }
+
+        e -= F::EXP_BIAS as i32 + 52 + 1;
+
+        ix &= F::SIG_MASK;
+        ix |= F::IMPLICIT_BIT;
+        ix <<= 1; // add a guard bit
+
+        Self { m: ix, e, neg }
+    }
+}
+
+/// Type-specific helpers that are not needed outside of fma.
+pub trait FmaHelper {
+    fn raise_underflow(self) -> Self;
+    fn raise_underflow2(self) -> Self;
+}
+
+impl FmaHelper for f64 {
+    fn raise_underflow(self) -> Self {
+        let x0_ffffff8p_63 = f64::from_bits(0x3bfffffff0000000); // 0x0.ffffff8p-63
+        let fltmin: f32 = (x0_ffffff8p_63 * f32::MIN_POSITIVE as f64 * self) as f32;
+        f64::MIN_POSITIVE / f32::MIN_POSITIVE as f64 * fltmin as f64
+    }
+
+    fn raise_underflow2(self) -> Self {
+        /* raise underflow portably, such that it
+         * cannot be optimized away */
+        let tiny: f64 = f64::MIN_POSITIVE / f32::MIN_POSITIVE as f64 * self;
+        (tiny * tiny) * (self - self)
+    }
+}
diff --git a/src/math/generic/mod.rs b/src/math/generic/mod.rs
index 68686b0b2..e19cc83a9 100644
--- a/src/math/generic/mod.rs
+++ b/src/math/generic/mod.rs
@@ -3,6 +3,7 @@ mod copysign;
 mod fabs;
 mod fdim;
 mod floor;
+mod fma;
 mod fmax;
 mod fmin;
 mod fmod;
@@ -17,6 +18,7 @@ pub use copysign::copysign;
 pub use fabs::fabs;
 pub use fdim::fdim;
 pub use floor::floor;
+pub use fma::fma;
 pub use fmax::fmax;
 pub use fmin::fmin;
 pub use fmod::fmod;
diff --git a/src/math/support/float_traits.rs b/src/math/support/float_traits.rs
index 1fe2cb424..24cf7d4b0 100644
--- a/src/math/support/float_traits.rs
+++ b/src/math/support/float_traits.rs
@@ -23,7 +23,9 @@ pub trait Float:
     type Int: Int<OtherSign = Self::SignedInt, Unsigned = Self::Int>;
 
     /// A int of the same width as the float
-    type SignedInt: Int + MinInt<OtherSign = Self::Int, Unsigned = Self::Int>;
+    type SignedInt: Int
+        + MinInt<OtherSign = Self::Int, Unsigned = Self::Int>
+        + ops::Neg<Output = Self::SignedInt>;
 
     const ZERO: Self;
     const NEG_ZERO: Self;
diff --git a/src/math/support/int_traits.rs b/src/math/support/int_traits.rs
index b403c658c..793a0f306 100644
--- a/src/math/support/int_traits.rs
+++ b/src/math/support/int_traits.rs
@@ -52,10 +52,14 @@ pub trait Int:
     + ops::Sub<Output = Self>
     + ops::Mul<Output = Self>
     + ops::Div<Output = Self>
+    + ops::Shl<i32, Output = Self>
+    + ops::Shl<u32, Output = Self>
+    + ops::Shr<i32, Output = Self>
     + ops::Shr<u32, Output = Self>
     + ops::BitXor<Output = Self>
     + ops::BitAnd<Output = Self>
     + cmp::Ord
+    + From<bool>
     + CastFrom<i32>
     + CastFrom<u16>
     + CastFrom<u32>
@@ -92,6 +96,7 @@
     fn wrapping_shr(self, other: u32) -> Self;
     fn rotate_left(self, other: u32) -> Self;
     fn overflowing_add(self, other: Self) -> (Self, bool);
+    fn overflowing_sub(self, other: Self) -> (Self, bool);
     fn leading_zeros(self) -> u32;
     fn ilog2(self) -> u32;
 }
@@ -150,6 +155,10 @@ macro_rules! int_impl_common {
            <Self>::overflowing_add(self, other)
        }
 
+        fn overflowing_sub(self, other: Self) -> (Self, bool) {
+            <Self>::overflowing_sub(self, other)
+        }
+
        fn leading_zeros(self) -> u32 {
            <Self>::leading_zeros(self)
        }
@@ -399,6 +408,30 @@ macro_rules! cast_into {
     )*};
 }
 
+macro_rules! cast_into_float {
+    ($ty:ty) => {
+        #[cfg(f16_enabled)]
+        cast_into_float!($ty; f16);
+
+        cast_into_float!($ty; f32, f64);
+
+        #[cfg(f128_enabled)]
+        cast_into_float!($ty; f128);
+    };
+    ($ty:ty; $($into:ty),*) => {$(
+        impl CastInto<$into> for $ty {
+            fn cast(self) -> $into {
+                debug_assert_eq!(self as $into as $ty, self, "inexact float cast");
+                self as $into
+            }
+
+            fn cast_lossy(self) -> $into {
+                self as $into
+            }
+        }
+    )*};
+}
+
 cast_into!(usize);
 cast_into!(isize);
 cast_into!(u8);
@@ -411,3 +444,9 @@ cast_into!(u64);
 cast_into!(i64);
 cast_into!(u128);
 cast_into!(i128);
+
+cast_into_float!(i8);
+cast_into_float!(i16);
+cast_into_float!(i32);
+cast_into_float!(i64);
+cast_into_float!(i128);
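-- 
Illustrative notes, not part of the patch:

The doc comment added above promises "rounded as one ternary operation". The following standalone sketch (it uses std's `f64::mul_add`, which has the same single-rounding contract as `fma`) shows a case where fusing the rounding changes the answer:

fn main() {
    // x * y = (1 + 2^-52)(1 - 2^-52) = 1 - 2^-104 exactly, which rounds to
    // 1.0 when the product is rounded on its own.
    let x = 1.0_f64 + f64::EPSILON; // 1 + 2^-52
    let y = 1.0_f64 - f64::EPSILON; // 1 - 2^-52

    let two_step = x * y - 1.0; // product rounds to 1.0, so this is 0.0
    let fused = x.mul_add(y, -1.0); // single rounding keeps the low bits

    assert_eq!(two_step, 0.0);
    assert_eq!(fused, -(f64::EPSILON * f64::EPSILON)); // -2^-104, exact
}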
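`Norm::from_float` handles subnormal inputs by scaling them with the exactly representable constant 2^(BITS - 1), so every input can be decomposed into significand and exponent the same way. A standalone sketch of the idea for `f64`, using the same magic constant as the patch:

fn main() {
    let sub = f64::from_bits(1); // 2^-1074, the smallest positive subnormal
    let x1p63 = f64::from_bits(0x43e0000000000000); // 2^63

    // The scaling is exact and the scaled value is normal, so its exponent
    // field is now nonzero while its significand bits are unchanged.
    let scaled = sub * x1p63; // 2^-1011
    let biased_e = ((scaled.to_bits() >> 52) & 0x7ff) as i32;
    assert_ne!(biased_e, 0);

    // Undoing the scale in the exponent recovers the true binary exponent.
    assert_eq!(biased_e - 1023 - 63, -1074);
}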
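Both the old and the new code shift significands right with a "sticky" low bit, e.g. `(nz.m >> d) | IntTy::<F>::from((nz.m << (sbits - d)) != zero)`: any bits shifted out are folded into bit 0 so a later rounding step can still tell the value was inexact. A minimal illustration:

fn main() {
    let m: u64 = 0b1011; // shifting right by 2 would discard the low bits 0b11
    let d = 2;

    // Nonzero iff any discarded bit was set; OR-ing it into bit 0 records
    // "inexact" without keeping the discarded value itself.
    let sticky = ((m << (64 - d)) != 0) as u64;
    let shifted = (m >> d) | sticky;

    assert_eq!(shifted, 0b11); // 0b10 from the shift, plus the sticky bit
}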