From 45432c89e86737dab6e9b8b1530edfdd6a3b887b Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Sat, 4 May 2024 13:04:07 -0400 Subject: [PATCH 01/20] Add saturating integer math --- src/OverflowContexts.jl | 3 +- src/base_ext.jl | 94 +++++++++++++++++++++++----- src/macros.jl | 56 +++++++++++++++-- test/runtests.jl | 135 +++++++++++++++++++++++++++++++++++++--- 4 files changed, 258 insertions(+), 30 deletions(-) diff --git a/src/OverflowContexts.jl b/src/OverflowContexts.jl index 2701a9d..754e1ab 100644 --- a/src/OverflowContexts.jl +++ b/src/OverflowContexts.jl @@ -4,9 +4,10 @@ __precompile__(false) include("macros.jl") include("base_ext.jl") -export @default_checked, @default_unchecked, @checked, @unchecked, +export @default_checked, @default_unchecked, @default_saturating, @checked, @unchecked, @saturating, unchecked_neg, unchecked_add, unchecked_sub, unchecked_mul, unchecked_negsub, unchecked_pow, unchecked_abs, checked_neg, checked_add, checked_sub, checked_mul, checked_pow, checked_negsub, checked_abs, + saturating_neg, saturating_add, saturating_sub, saturating_mul, saturating_pow, saturating_negsub, saturating_abs, SignedBitInteger end # module diff --git a/src/base_ext.jl b/src/base_ext.jl index 60941b8..ece0b46 100644 --- a/src/base_ext.jl +++ b/src/base_ext.jl @@ -9,16 +9,27 @@ if VERSION ≥ v"1.11-alpha" end # convert multi-argument calls into nested two-argument calls +checked_add(a, b, c, xs...) = @checked (@_inline_meta; afoldl(+, (+)((+)(a, b), c), xs...)) +checked_sub(a, b, c, xs...) = @checked (@_inline_meta; afoldl(-, (-)((-)(a, b), c), xs...)) +checked_mul(a, b, c, xs...) = @checked (@_inline_meta; afoldl(*, (*)((*)(a, b), c), xs...)) + unchecked_add(a, b, c, xs...) = @unchecked (@_inline_meta; afoldl(+, (+)((+)(a, b), c), xs...)) unchecked_sub(a, b, c, xs...) = @unchecked (@_inline_meta; afoldl(-, (-)((-)(a, b), c), xs...)) unchecked_mul(a, b, c, xs...) = @unchecked (@_inline_meta; afoldl(*, (*)((*)(a, b), c), xs...)) -checked_add(a, b, c, xs...) = @checked (@_inline_meta; afoldl(+, (+)((+)(a, b), c), xs...)) -checked_sub(a, b, c, xs...) = @checked (@_inline_meta; afoldl(-, (-)((-)(a, b), c), xs...)) -checked_mul(a, b, c, xs...) = @checked (@_inline_meta; afoldl(*, (*)((*)(a, b), c), xs...)) +saturating_add(a, b, c, xs...) = @saturating (@_inline_meta; afoldl(+, (+)((+)(a, b), c), xs...)) +saturating_sub(a, b, c, xs...) = @saturating (@_inline_meta; afoldl(-, (-)((-)(a, b), c), xs...)) +saturating_mul(a, b, c, xs...) = @saturating (@_inline_meta; afoldl(*, (*)((*)(a, b), c), xs...)) # passthrough for non-numbers +checked_neg(x) = Base.:-(x) +checked_add(x, y) = Base.:+(x, y) +checked_sub(x, y) = Base.:-(x, y) +checked_mul(x, y) = Base.:*(x, y) +checked_pow(x, y) = Base.:^(x, y) +checked_abs(x) = Base.abs(x) + unchecked_neg(x) = Base.:-(x) unchecked_add(x, y) = Base.:+(x, y) unchecked_sub(x, y) = Base.:-(x, y) @@ -26,35 +37,45 @@ unchecked_mul(x, y) = Base.:*(x, y) unchecked_pow(x, y) = Base.:^(x, y) unchecked_abs(x) = Base.abs(x) -checked_neg(x) = Base.:-(x) -checked_add(x, y) = Base.:+(x, y) -checked_sub(x, y) = Base.:-(x, y) -checked_mul(x, y) = Base.:*(x, y) -checked_pow(x, y) = Base.:^(x, y) -checked_abs(x) = Base.abs(x) +saturating_neg(x) = Base.:-(x) +saturating_add(x, y) = Base.:+(x, y) +saturating_sub(x, y) = Base.:-(x, y) +saturating_mul(x, y) = Base.:*(x, y) +saturating_pow(x, y) = Base.:^(x, y) +saturating_abs(x) = Base.abs(x) # promote unmatched number types to same type +checked_add(x::Number, y::Number) = checked_add(promote(x, y)...) +checked_sub(x::Number, y::Number) = checked_sub(promote(x, y)...) +checked_mul(x::Number, y::Number) = checked_mul(promote(x, y)...) +checked_pow(x::Number, y::Number) = checked_pow(promote(x, y)...) + unchecked_add(x::Number, y::Number) = unchecked_add(promote(x, y)...) unchecked_sub(x::Number, y::Number) = unchecked_sub(promote(x, y)...) unchecked_mul(x::Number, y::Number) = unchecked_mul(promote(x, y)...) unchecked_pow(x::Number, y::Number) = unchecked_pow(promote(x, y)...) -checked_add(x::Number, y::Number) = checked_add(promote(x, y)...) -checked_sub(x::Number, y::Number) = checked_sub(promote(x, y)...) -checked_mul(x::Number, y::Number) = checked_mul(promote(x, y)...) -checked_pow(x::Number, y::Number) = checked_pow(promote(x, y)...) +saturating_add(x::Number, y::Number) = saturating_add(promote(x, y)...) +saturating_sub(x::Number, y::Number) = saturating_sub(promote(x, y)...) +saturating_mul(x::Number, y::Number) = saturating_mul(promote(x, y)...) +saturating_pow(x::Number, y::Number) = saturating_pow(promote(x, y)...) # passthrough for same-type numbers that aren't integers +checked_add(x::T, y::T) where T <: Number = Base.:+(x, y) +checked_sub(x::T, y::T) where T <: Number = Base.:-(x, y) +checked_mul(x::T, y::T) where T <: Number = Base.:*(x, y) +checked_pow(x::T, y::T) where T <: Number = Base.:^(x, y) + unchecked_add(x::T, y::T) where T <: Number = Base.:+(x, y) unchecked_sub(x::T, y::T) where T <: Number = Base.:-(x, y) unchecked_mul(x::T, y::T) where T <: Number = Base.:*(x, y) unchecked_pow(x::T, y::T) where T <: Number = Base.:^(x, y) -checked_add(x::T, y::T) where T <: Number = Base.:+(x, y) -checked_sub(x::T, y::T) where T <: Number = Base.:-(x, y) -checked_mul(x::T, y::T) where T <: Number = Base.:*(x, y) -checked_pow(x::T, y::T) where T <: Number = Base.:^(x, y) +saturating_add(x::T, y::T) where T <: Number = Base.:+(x, y) +saturating_sub(x::T, y::T) where T <: Number = Base.:-(x, y) +saturating_mul(x::T, y::T) where T <: Number = Base.:*(x, y) +saturating_pow(x::T, y::T) where T <: Number = Base.:^(x, y) # core methods @@ -65,6 +86,45 @@ unchecked_mul(x::T, y::T) where T <: BitInteger = mul_int(x, y) unchecked_pow(x::T, y::S) where {T <: BitInteger, S <: BitInteger} = power_by_squaring(x, y) unchecked_abs(x::T) where T <: SignedBitInteger = flipsign(x, x) +saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x) +function saturating_add(x::T, y::T) where T <: BitInteger + result, overflow_flag = add_with_overflow(x, y) + if overflow_flag + return sign(x) > 0 ? + typemax(T) : + typemin(T) + end + return result +end +function saturating_sub(x::T, y::T) where T <: BitInteger + result, overflow_flag = sub_with_overflow(x, y) + if overflow_flag + return y > x ? + typemin(T) : + typemax(T) + end + return result +end +function saturating_mul(x::T, y::T) where T <: BitInteger + result, overflow_flag = mul_with_overflow(x, y) + if overflow_flag + return sign(x) == sign(y) ? + typemax(T) : + typemin(T) + end + return result +end +function saturating_pow(x::T, y::S) where {T <: BitInteger, S <: BitInteger} + result, overflow_flag = pow_with_overflow(x, y) + if overflow_flag + return sign(x) > 0 ? + typemax(T) : + typemin(T) + end + return result +end +saturating_abs(x::T) where T <: SignedBitInteger = x == typemin(T) ? typemax(T) : flipsign(x, x) + if VERSION < v"1.11" # Base.Checked only gained checked powers in 1.11 diff --git a/src/macros.jl b/src/macros.jl index 546a456..cb689fe 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -41,6 +41,36 @@ macro default_unchecked() end end +""" + @default_saturating + +Redirect default integer math to saturating operators for the current module. Only works at top-level. +""" +macro default_saturating() + quote + any(Base.isbindingresolved.(Ref(@__MODULE__), (:+, :-, :*, :^, :abs))) && + error("A default context may only be set before any reference to the affected methods (+, -, *, ^, abs) in the target module.") + (@__MODULE__).eval(:(-(x) = OverflowContexts.saturating_neg(x))) + (@__MODULE__).eval(:(+(x...) = OverflowContexts.saturating_add(x...))) + (@__MODULE__).eval(:(-(x...) = OverflowContexts.saturating_sub(x...))) + (@__MODULE__).eval(:(*(x...) = OverflowContexts.saturating_mul(x...))) + (@__MODULE__).eval(:(^(x...) = OverflowContexts.saturating_pow(x...))) + (@__MODULE__).eval(:(abs(x) = OverflowContexts.saturating_abs(x))) + nothing + end +end + +""" + @checked expr + +Perform all integer operations in `expr` using overflow-checked arithmetic. +""" +macro checked(expr) + isa(expr, Expr) || return expr + expr = copy(expr) + return esc(replace_op!(expr, op_checked)) +end + """ @unchecked expr @@ -53,14 +83,14 @@ macro unchecked(expr) end """ - @checked expr + @saturating expr -Perform all integer operations in `expr` using overflow-checked arithmetic. +Perform all integer operations in `expr` using saturating arithmetic. """ -macro checked(expr) +macro saturating(expr) isa(expr, Expr) || return expr expr = copy(expr) - return esc(replace_op!(expr, op_checked)) + return esc(replace_op!(expr, op_saturating)) end const op_checked = Dict( @@ -91,11 +121,27 @@ const op_unchecked = Dict( :abs => :(unchecked_abs) ) +const op_saturating = Dict( + Symbol("unary-") => :(saturating_neg), + Symbol("ambig-") => :(saturating_negsub), + :+ => :(saturating_add), + :- => :(saturating_sub), + :* => :(saturating_mul), + :^ => :(saturating_pow), + :+= => :(saturating_add), + :-= => :(saturating_sub), + :*= => :(saturating_mul), + :^= => :(saturating_pow), + :abs => :(saturating_abs) +) + # resolve ambiguity when `-` used as symbol unchecked_negsub(x) = unchecked_neg(x) unchecked_negsub(x, y) = unchecked_sub(x, y) checked_negsub(x) = checked_neg(x) checked_negsub(x, y) = checked_sub(x, y) +saturating_negsub(x) = saturating_neg(x) +saturating_negsub(x, y) = saturating_sub(x, y) # copied from CheckedArithmetic.jl and modified it function replace_op!(expr::Expr, op_map::Dict) @@ -144,7 +190,7 @@ function replace_op!(expr::Expr, op_map::Dict) op = get(op_map, op, op) expr.head = :(=) expr.args[2] = Expr(:call, op, target, arg) - elseif !isexpr(expr, :macrocall) || expr.args[1] ∉ (Symbol("@checked"), Symbol("@unchecked")) + elseif !isexpr(expr, :macrocall) || expr.args[1] ∉ (Symbol("@checked"), Symbol("@unchecked"), Symbol("@saturating")) for a in expr.args if isa(a, Expr) replace_op!(a, op_map) diff --git a/test/runtests.jl b/test/runtests.jl index 5b89cc7..be2f2ab 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,27 @@ end @test @unchecked(abs(typemin(Int))) == typemin(Int) end +@testset "saturating expressions" begin + @test @saturating(-typemin(Int)) == typemax(Int) + @test @saturating(-UInt(1)) == typemin(UInt) + + @test @saturating(typemax(Int) + 1) == typemax(Int) + @test @saturating(typemax(UInt) + 1) == typemax(UInt) + + @test @saturating(typemin(Int) - 1) == typemin(Int) + @test @saturating(typemin(UInt) - 1) == typemin(UInt) + + @test @saturating(typemax(Int) * 2) == typemax(Int) + @test @saturating(typemin(Int) * 2) == typemin(Int) + @test @saturating(typemax(UInt) * 2) == typemax(UInt) + + @test @saturating(typemax(Int) ^ 2) == typemax(Int) + @test @saturating(typemin(Int) ^ 2) == typemin(Int) + @test @saturating(typemax(UInt) ^ 2) == typemax(UInt) + + @test @saturating(abs(typemin(Int))) == typemax(Int) +end + @testset "juxtaposed multiplication works" begin @test_throws OverflowError @checked 2typemax(Int) @test_throws OverflowError @checked 2typemin(Int) @@ -52,38 +73,41 @@ end @test @unchecked(2typemax(Int)) == -2 @test @unchecked(2typemin(Int)) == 0 @test @unchecked(2typemax(UInt)) == typemax(UInt) - 1 + @test @saturating(2typemax(Int)) == typemax(Int) + @test @saturating(2typemin(Int)) == typemin(Int) + @test @saturating(2typemax(UInt)) == typemax(UInt) end @testset "exhaustive checks over 16 bit math" begin T = Int16 @testset "negation" begin for i ∈ typemin(T) + T(1):typemax(T) - @test @checked(-i) == @unchecked(-i) == -i + @test @checked(-i) == @unchecked(-i) == @saturating(-i) == -i end end @testset "addition" begin for i ∈ typemin(T):typemax(T) - Int8(1) - @test @checked(i + T(1)) == @unchecked(i + T(1)) == i + T(1) + @test @checked(i + T(1)) == @unchecked(i + T(1)) == @saturating(i + T(1)) == i + T(1) end end @testset "subtraction" begin for i ∈ typemin(T) + T(1):typemax(T) - @test @checked(i - T(1)) == @unchecked(i - T(1)) == i - T(1) + @test @checked(i - T(1)) == @unchecked(i - T(1)) == @saturating(i - T(1)) == i - T(1) end end @testset "multiplication" begin for i ∈ typemin(T) ÷ T(2):typemax(T) ÷ T(2) - @test @checked(2i) == @unchecked(2i) == 2i + @test @checked(2i) == @unchecked(2i) == @saturating(2i) == 2i end end @testset "power" begin for i ∈ ceil(T, -√(typemax(T))):floor(T, √(typemax(T))) - @test @checked(i ^ 2) == @unchecked(i ^ 2) == i ^ 2 + @test @checked(i ^ 2) == @unchecked(i ^ 2) == @saturating(i ^ 2) == i ^ 2 end end @testset "abs" begin for i ∈ typemin(T) + T(1):typemax(T) - @test @checked(abs(i)) == @unchecked(abs(i)) == abs(i) + @test @checked(abs(i)) == @unchecked(abs(i)) == @saturating(abs(i)) == abs(i) end end end @@ -91,38 +115,55 @@ end @testset "lowest-level macro takes priority" begin @checked begin @test @unchecked(typemax(Int) + 1) == typemin(Int) + @test @saturating(typemax(Int) + 1) == typemax(Int) end @unchecked begin @test_throws OverflowError @checked typemax(Int) + 1 + @test @saturating(typemax(Int) + 1) == typemax(Int) + end + @saturating begin + @test @unchecked(typemax(Int) + 1) == typemin(Int) + @test_throws OverflowError @checked typemax(Int) + 1 end end @testset "literals passthrough" begin @test @checked(-1) == -1 @test @unchecked(-1) == -1 + @test @saturating(-1) == -1 end @testset "non-integer math still works" begin @test @checked(-1.0) == -1 @test @unchecked(-1.0) == -1 + @test @saturating(-1.0) == -1 @test @checked(1.0 + 3.0) == 4.0 @test @unchecked(1.0 + 3.0) == 4.0 + @test @saturating(1.0 + 3.0) == 4.0 @test @checked(1 + 3.0) == 4.0 @test @unchecked(1 + 3.0) == 4.0 + @test @saturating(1 + 3.0) == 4.0 @test @checked(1.0 - 3.0) == -2.0 @test @unchecked(1.0 - 3.0) == -2.0 + @test @saturating(1.0 - 3.0) == -2.0 @test @checked(1 - 3.0) == -2.0 @test @unchecked(1 - 3.0) == -2.0 + @test @saturating(1 - 3.0) == -2.0 @test @checked(1.0 * 3.0) == 3.0 @test @unchecked(1.0 * 3.0) == 3.0 + @test @saturating(1.0 * 3.0) == 3.0 @test @checked(1 * 3.0) == 3.0 @test @unchecked(1 * 3.0) == 3.0 + @test @saturating(1 * 3.0) == 3.0 @test @checked(1.0 ^ 3.0) == 1.0 @test @unchecked(1.0 ^ 3.0) == 1.0 + @test @saturating(1.0 ^ 3.0) == 1.0 @test @checked(1 ^ 3.0) == 1.0 @test @unchecked(1 ^ 3.0) == 1.0 + @test @saturating(1 ^ 3.0) == 1.0 @test @checked(abs(-1.0)) == 1.0 @test @unchecked(abs(-1.0)) == 1.0 + @test @saturating(abs(-1.0)) == 1.0 end @testset "symbol replacement" begin @@ -132,11 +173,17 @@ end expr = @macroexpand @unchecked foldl(+, []) @test expr.args[2] == :unchecked_add + expr = @macroexpand @saturating foldl(+, []) + @test expr.args[2] == :saturating_add + expr = @macroexpand @checked foldl(-, []) @test expr.args[2] == :checked_negsub expr = @macroexpand @unchecked foldl(-, []) @test expr.args[2] == :unchecked_negsub + + expr = @macroexpand @saturating foldl(-, []) + @test expr.args[2] == :saturating_negsub expr = @macroexpand @checked foldl(*, []) @test expr.args[2] == :checked_mul @@ -144,42 +191,63 @@ end expr = @macroexpand @unchecked foldl(*, []) @test expr.args[2] == :unchecked_mul + expr = @macroexpand @saturating foldl(*, []) + @test expr.args[2] == :saturating_mul + expr = @macroexpand @checked foldl(^, []) @test expr.args[2] == :checked_pow expr = @macroexpand @unchecked foldl(^, []) @test expr.args[2] == :unchecked_pow + expr = @macroexpand @saturating foldl(^, []) + @test expr.args[2] == :saturating_pow + expr = @macroexpand @checked foldl(:abs, []) @test expr.args[2] == :checked_abs expr = @macroexpand @unchecked foldl(:abs, []) @test expr.args[2] == :unchecked_abs + + expr = @macroexpand @saturating foldl(:abs, []) + @test expr.args[2] == :saturating_abs end @testset "negsub helper methods dispatch correctly" begin + @test checked_negsub(1) == -1 + @test checked_negsub(1, 2) == 1 - 2 @test unchecked_negsub(1) == -1 @test unchecked_negsub(1, 2) == 1 - 2 + @test saturating_negsub(1) == -1 + @test saturating_negsub(1, 2) == 1 - 2 end @testset "in-place assignement" begin a = typemax(Int) @test_throws OverflowError @checked a += 1 + @saturating a += 1 + @test a == typemax(Int) @unchecked a += 1 @test a == typemin(Int) a = typemin(Int) @test_throws OverflowError @checked a -= 1 + @saturating a -= 1 + @test a == typemin(Int) @unchecked a -= 1 @test a == typemax(Int) a = typemax(Int) @test_throws OverflowError @checked a *= 2 + @saturating a *= 2 + @test a == typemax(Int) @unchecked a *= 2 @test a == -2 a = typemax(Int) @test_throws OverflowError @checked a ^= 2 + @saturating a ^= 2 + @test a == typemax(Int) @unchecked a ^= 2 @test a == 1 end @@ -189,11 +257,19 @@ end foominus(x, y) = x - y end +@saturating begin + barplus(x, y) = x + y + barminus(x, y) = x - y +end + @testset "rewrite inside block body" begin @test fooplus(0x10, 0x20) === 0x30 @test_throws OverflowError fooplus(0xf0, 0x20) @test foominus(0x30, 0x20) === 0x10 @test_throws OverflowError foominus(0x20, 0x30) + + @test barplus(0xf0, 0x20) === 0xff + @test barminus(0x20, 0x30) === 0x00 end module CheckedModule @@ -206,30 +282,67 @@ module CheckedModule @default_unchecked testfunc() = @test typemax(Int) + 1 == typemin(Int) end + + module NestedSaturatingModule + using OverflowContexts, Test + @default_saturating + testfunc() = @test typemax(Int) + 1 == typemax(Int) + end end module UncheckedModule using OverflowContexts, Test @default_unchecked testfunc() = @test typemax(Int) + 1 == typemin(Int) + module NestedCheckedModule using OverflowContexts, Test @default_checked testfunc() = @test_throws OverflowError typemax(Int) + 1 end + + module NestedSaturatingModule + using OverflowContexts, Test + @default_saturating + testfunc() = @test typemax(Int) + 1 == typemax(Int) + end +end + +module SaturatingModule + using OverflowContexts, Test + @default_saturating + testfunc() = @test typemax(Int) + 1 == typemax(Int) + + module NestedCheckedModule + using OverflowContexts, Test + @default_checked + testfunc() = @test_throws OverflowError typemax(Int) + 1 + end + + module NestedUncheckedModule + using OverflowContexts, Test + @default_unchecked + testfunc() = @test typemax(Int) + 1 == typemin(Int) + end end @testset "module-specific contexts" begin CheckedModule.testfunc() CheckedModule.NestedUncheckedModule.testfunc() + CheckedModule.NestedSaturatingModule.testfunc() UncheckedModule.testfunc() UncheckedModule.NestedCheckedModule.testfunc() + UncheckedModule.NestedSaturatingModule.testfunc() + SaturatingModule.testfunc() + SaturatingModule.NestedCheckedModule.testfunc() + SaturatingModule.NestedUncheckedModule.testfunc() end @testset "default methods error if Base symbol already resolved" begin x = 1 + 1 @test_throws ErrorException @default_checked @test_throws ErrorException @default_unchecked + @test_throws ErrorException @default_saturating (@__MODULE__).eval(:( module BadCheckedModule @@ -244,9 +357,17 @@ end x = 1 + 1 @test_throws ErrorException @default_unchecked end)) + + (@__MODULE__).eval(:( + module BadSaturatingModule + using OverflowContexts, Test + x = 1 + 1 + @test_throws ErrorException @default_saturating + end)) end @testset "ensure pow methods don't promote on the power" begin - @test typeof(@unchecked 3 ^ UInt(4)) == Int @test typeof(@checked 3 ^ UInt(4)) == Int + @test typeof(@unchecked 3 ^ UInt(4)) == Int + @test typeof(@saturating 3 ^ UInt(4)) == Int end From 62fccc74f6b3c8b3304ff9b00caab6f01049c301 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Tue, 7 May 2024 11:41:09 -0400 Subject: [PATCH 02/20] Adapt to merged changes --- src/OverflowContexts.jl | 4 ++-- src/base_ext.jl | 43 ++++++++++++++++++----------------------- src/macros.jl | 9 +++++++-- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/OverflowContexts.jl b/src/OverflowContexts.jl index 4f8034b..0c338b0 100644 --- a/src/OverflowContexts.jl +++ b/src/OverflowContexts.jl @@ -4,8 +4,8 @@ include("macros.jl") include("base_ext.jl") export @default_checked, @default_unchecked, @default_saturating, @checked, @unchecked, @saturating, + checked_neg, checked_add, checked_sub, checked_mul, checked_pow, checked_negsub, checked_abs, unchecked_neg, unchecked_add, unchecked_sub, unchecked_mul, unchecked_negsub, unchecked_pow, unchecked_abs, - checked_neg, checked_add, checked_sub, checked_mul, checked_pow, checked_negsub, checked_abs - saturating_neg, saturating_add, saturating_sub, saturating_mul, saturating_pow, saturating_negsub, saturating_abs, + saturating_neg, saturating_add, saturating_sub, saturating_mul, saturating_pow, saturating_negsub, saturating_abs end # module diff --git a/src/base_ext.jl b/src/base_ext.jl index 056dcdb..2e77133 100644 --- a/src/base_ext.jl +++ b/src/base_ext.jl @@ -1,11 +1,12 @@ -import Base: promote, afoldl, @_inline_meta -import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked_abs +import Base: BitInteger, promote, afoldl, @_inline_meta +import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked_abs, + add_with_overflow, sub_with_overflow, mul_with_overflow if VERSION ≥ v"1.11-alpha" - import Base.Checked: checked_pow + import Base.Checked: checked_pow, pow_with_overflow else - import Base: BitInteger, throw_domerr_powbysq, to_power_type - import Base.Checked: mul_with_overflow, throw_overflowerr_binaryop + import Base: throw_domerr_powbysq, to_power_type + import Base.Checked: throw_overflowerr_binaryop end # The Base methods have unchecked semantics, so just pass through @@ -21,32 +22,23 @@ unchecked_abs(x...) = Base.abs(x...) checked_add(a, b, c, xs...) = @checked (@_inline_meta; afoldl(+, (+)((+)(a, b), c), xs...)) checked_sub(a, b, c, xs...) = @checked (@_inline_meta; afoldl(-, (-)((-)(a, b), c), xs...)) checked_mul(a, b, c, xs...) = @checked (@_inline_meta; afoldl(*, (*)((*)(a, b), c), xs...)) + saturating_add(a, b, c, xs...) = @saturating (@_inline_meta; afoldl(+, (+)((+)(a, b), c), xs...)) saturating_sub(a, b, c, xs...) = @saturating (@_inline_meta; afoldl(-, (-)((-)(a, b), c), xs...)) saturating_mul(a, b, c, xs...) = @saturating (@_inline_meta; afoldl(*, (*)((*)(a, b), c), xs...)) -# passthrough for non-numbers -unchecked_neg(x) = Base.:-(x) -unchecked_add(x, y) = Base.:+(x, y) -unchecked_sub(x, y) = Base.:-(x, y) -unchecked_mul(x, y) = Base.:*(x, y) -unchecked_pow(x, y) = Base.:^(x, y) -unchecked_abs(x) = Base.abs(x) - -checked_neg(x) = Base.:-(x) -checked_add(x, y) = Base.:+(x, y) -checked_sub(x, y) = Base.:-(x, y) -checked_mul(x, y) = Base.:*(x, y) -checked_pow(x, y) = Base.:^(x, y) -checked_abs(x) = Base.abs(x) - # promote unmatched number types to same type checked_add(x::Number, y::Number) = checked_add(promote(x, y)...) checked_sub(x::Number, y::Number) = checked_sub(promote(x, y)...) checked_mul(x::Number, y::Number) = checked_mul(promote(x, y)...) checked_pow(x::Number, y::Number) = checked_pow(promote(x, y)...) +saturating_add(x::Number, y::Number) = saturating_add(promote(x, y)...) +saturating_sub(x::Number, y::Number) = saturating_sub(promote(x, y)...) +saturating_mul(x::Number, y::Number) = saturating_mul(promote(x, y)...) +saturating_pow(x::Number, y::Number) = saturating_pow(promote(x, y)...) + # fallback to `unchecked_` for `Number` types that don't have more specific `checked_` methods checked_neg(x::T) where T <: Number = unchecked_neg(x) @@ -56,10 +48,12 @@ checked_mul(x::T, y::T) where T <: Number = unchecked_mul(x, y) checked_pow(x::T, y::T) where T <: Number = unchecked_pow(x, y) checked_abs(x::T) where T <: Number = unchecked_abs(x) -saturating_add(x::T, y::T) where T <: Number = Base.:+(x, y) -saturating_sub(x::T, y::T) where T <: Number = Base.:-(x, y) -saturating_mul(x::T, y::T) where T <: Number = Base.:*(x, y) -saturating_pow(x::T, y::T) where T <: Number = Base.:^(x, y) +saturating_neg(x::T) where T <: Number = unchecked_neg(x) +saturating_add(x::T, y::T) where T <: Number = unchecked_add(x, y) +saturating_sub(x::T, y::T) where T <: Number = unchecked_sub(x, y) +saturating_mul(x::T, y::T) where T <: Number = unchecked_mul(x, y) +saturating_pow(x::T, y::T) where T <: Number = unchecked_pow(x, y) +saturating_abs(x::T) where T <: Number = unchecked_abs(x) # fallback to `unchecked_` for non-`Number` types checked_neg(x) = unchecked_neg(x) @@ -107,6 +101,7 @@ function saturating_pow(x::T, y::S) where {T <: BitInteger, S <: BitInteger} end return result end +const SignedBitInteger = Union{Int8, Int16, Int32, Int64, Int128} saturating_abs(x::T) where T <: SignedBitInteger = x == typemin(T) ? typemax(T) : flipsign(x, x) if VERSION < v"1.11" diff --git a/src/macros.jl b/src/macros.jl index 19fa0d2..1666bff 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -57,14 +57,19 @@ Redirect default integer math to saturating operators for the current module. On """ macro default_saturating() quote - any(Base.isbindingresolved.(Ref(@__MODULE__), (:+, :-, :*, :^, :abs))) && - error("A default context may only be set before any reference to the affected methods (+, -, *, ^, abs) in the target module.") + if !isdefined(@__MODULE__, :__OverflowContextDefaultSet) + any(Base.isbindingresolved.(Ref(@__MODULE__), op_method_symbols)) && + error("A default context may only be set before any reference to the affected methods (+, -, *, ^, abs) in the target module.") + else + @warn "A previous default was set for this module. Previously defined methods in this module will be recompiled with this new default." + end (@__MODULE__).eval(:(-(x) = OverflowContexts.saturating_neg(x))) (@__MODULE__).eval(:(+(x...) = OverflowContexts.saturating_add(x...))) (@__MODULE__).eval(:(-(x...) = OverflowContexts.saturating_sub(x...))) (@__MODULE__).eval(:(*(x...) = OverflowContexts.saturating_mul(x...))) (@__MODULE__).eval(:(^(x...) = OverflowContexts.saturating_pow(x...))) (@__MODULE__).eval(:(abs(x) = OverflowContexts.saturating_abs(x))) + (@__MODULE__).eval(:(__OverflowContextDefaultSet = true)) nothing end end From 48c284636f5ee316503b37f758aedf071ad0a53e Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Tue, 7 May 2024 11:49:09 -0400 Subject: [PATCH 03/20] format fix --- src/base_ext.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/base_ext.jl b/src/base_ext.jl index 2e77133..a38d293 100644 --- a/src/base_ext.jl +++ b/src/base_ext.jl @@ -55,6 +55,7 @@ saturating_mul(x::T, y::T) where T <: Number = unchecked_mul(x, y) saturating_pow(x::T, y::T) where T <: Number = unchecked_pow(x, y) saturating_abs(x::T) where T <: Number = unchecked_abs(x) + # fallback to `unchecked_` for non-`Number` types checked_neg(x) = unchecked_neg(x) checked_add(x, y) = unchecked_add(x, y) @@ -64,6 +65,7 @@ checked_pow(x, y) = unchecked_pow(x, y) checked_abs(x) = unchecked_abs(x) +# saturating implementations saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x) function saturating_add(x::T, y::T) where T <: BitInteger result, overflow_flag = add_with_overflow(x, y) From 34511ad951bfcb20c6f7d282ac8cb8a7c4c618f1 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Tue, 7 May 2024 11:51:03 -0400 Subject: [PATCH 04/20] Stop doubling up CI runs --- .github/workflows/main.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 350c19e..35dae21 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,7 +1,8 @@ name: ci on: - push: + branches: + - main pull_request: jobs: From ae9b6b78dd6b9f34cd2e8cab848021f65243142f Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Tue, 7 May 2024 11:57:30 -0400 Subject: [PATCH 05/20] Fix workflow --- .github/workflows/main.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 35dae21..0a49e67 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,8 +1,9 @@ name: ci on: - branches: - - main + pull: + branches: + - main pull_request: jobs: From 43736e41deb56106f74f89100f2ac9e4e8276e9d Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Tue, 7 May 2024 12:44:01 -0400 Subject: [PATCH 06/20] Add widen/clamp versions for add/sub smaller integers --- src/base_ext.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/base_ext.jl b/src/base_ext.jl index a38d293..5cd96d7 100644 --- a/src/base_ext.jl +++ b/src/base_ext.jl @@ -66,8 +66,12 @@ checked_abs(x) = unchecked_abs(x) # saturating implementations +# widen/clamp reduces to a saturating intrinsic on LLVM for signed integers through 64 bits for +/- +# for unsigned it does not right now (Julia 1.11), but it is still faster than using the with_overflow methods saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x) -function saturating_add(x::T, y::T) where T <: BitInteger +saturating_add(x::T, y::T) where T <: BitInteger = + clamp(widen(x) + widen(y), T) +function saturating_add(x::T, y::T) where T <: Union{Int128, UInt128} result, overflow_flag = add_with_overflow(x, y) if overflow_flag return sign(x) > 0 ? @@ -76,7 +80,9 @@ function saturating_add(x::T, y::T) where T <: BitInteger end return result end -function saturating_sub(x::T, y::T) where T <: BitInteger +saturating_sub(x::T, y::T) where T <: BitInteger = + clamp(widen(x) - widen(y), T) +function saturating_sub(x::T, y::T) where T <: Union{Int128, UInt128} result, overflow_flag = sub_with_overflow(x, y) if overflow_flag return y > x ? From 9524143336dc492fa73a159c5128e1646982f9d1 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Tue, 7 May 2024 12:48:16 -0400 Subject: [PATCH 07/20] Add comment --- src/base_ext.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/base_ext.jl b/src/base_ext.jl index 5cd96d7..1ae593b 100644 --- a/src/base_ext.jl +++ b/src/base_ext.jl @@ -68,6 +68,7 @@ checked_abs(x) = unchecked_abs(x) # saturating implementations # widen/clamp reduces to a saturating intrinsic on LLVM for signed integers through 64 bits for +/- # for unsigned it does not right now (Julia 1.11), but it is still faster than using the with_overflow methods +# But we don't want to widen into a BigInt, so we use the naive approach for Int128 saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x) saturating_add(x::T, y::T) where T <: BitInteger = clamp(widen(x) + widen(y), T) From 08b18273e2708059e63b2bfcce49047e8a62d3d5 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Tue, 7 May 2024 23:53:05 -0400 Subject: [PATCH 08/20] Better saturating performance w/ intrinsics --- src/base_ext.jl | 71 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 48 insertions(+), 23 deletions(-) diff --git a/src/base_ext.jl b/src/base_ext.jl index 1ae593b..e4267f8 100644 --- a/src/base_ext.jl +++ b/src/base_ext.jl @@ -70,28 +70,50 @@ checked_abs(x) = unchecked_abs(x) # for unsigned it does not right now (Julia 1.11), but it is still faster than using the with_overflow methods # But we don't want to widen into a BigInt, so we use the naive approach for Int128 saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x) -saturating_add(x::T, y::T) where T <: BitInteger = - clamp(widen(x) + widen(y), T) -function saturating_add(x::T, y::T) where T <: Union{Int128, UInt128} - result, overflow_flag = add_with_overflow(x, y) - if overflow_flag - return sign(x) > 0 ? - typemax(T) : - typemin(T) - end - return result -end -saturating_sub(x::T, y::T) where T <: BitInteger = - clamp(widen(x) - widen(y), T) -function saturating_sub(x::T, y::T) where T <: Union{Int128, UInt128} - result, overflow_flag = sub_with_overflow(x, y) - if overflow_flag - return y > x ? - typemin(T) : - typemax(T) - end - return result -end + +using Base: llvmcall +saturating_add(x::Int8, y::Int8) = + ccall("llvm.sadd.sat.i8", llvmcall, Int8, (Int8, Int8), x, y) +saturating_add(x::Int16, y::Int16) = + ccall("llvm.sadd.sat.i16", llvmcall, Int16, (Int16, Int16), x, y) +saturating_add(x::Int32, y::Int32) = + ccall("llvm.sadd.sat.i32", llvmcall, Int32, (Int32, Int32), x, y) +saturating_add(x::Int64, y::Int64) = + ccall("llvm.sadd.sat.i64", llvmcall, Int64, (Int64, Int64), x, y) +saturating_add(x::Int128, y::Int128) = + ccall("llvm.sadd.sat.i128", llvmcall, Int128, (Int128, Int128), x, y) +saturating_add(x::UInt8, y::UInt8) = + ccall("llvm.uadd.sat.i8", llvmcall, UInt8, (UInt8, UInt8), x, y) +saturating_add(x::UInt16, y::UInt16) = + ccall("llvm.uadd.sat.i16", llvmcall, UInt16, (UInt16, UInt16), x, y) +saturating_add(x::UInt32, y::UInt32) = + ccall("llvm.uadd.sat.i32", llvmcall, UInt32, (UInt32, UInt32), x, y) +saturating_add(x::UInt64, y::UInt64) = + ccall("llvm.uadd.sat.i64", llvmcall, UInt64, (UInt64, UInt64), x, y) +saturating_add(x::UInt128, y::UInt128) = + ccall("llvm.uadd.sat.i128", llvmcall, UInt128, (UInt128, UInt128), x, y) + +saturating_sub(x::Int8, y::Int8) = + ccall("llvm.ssub.sat.i8", llvmcall, Int8, (Int8, Int8), x, y) +saturating_sub(x::Int16, y::Int16) = + ccall("llvm.ssub.sat.i16", llvmcall, Int16, (Int16, Int16), x, y) +saturating_sub(x::Int32, y::Int32) = + ccall("llvm.ssub.sat.i32", llvmcall, Int32, (Int32, Int32), x, y) +saturating_sub(x::Int64, y::Int64) = + ccall("llvm.ssub.sat.i64", llvmcall, Int64, (Int64, Int64), x, y) +saturating_sub(x::Int128, y::Int128) = + ccall("llvm.ssub.sat.i128", llvmcall, Int128, (Int128, Int128), x, y) +saturating_sub(x::UInt8, y::UInt8) = + ccall("llvm.usub.sat.i8", llvmcall, UInt8, (UInt8, UInt8), x, y) +saturating_sub(x::UInt16, y::UInt16) = + ccall("llvm.usub.sat.i16", llvmcall, UInt16, (UInt16, UInt16), x, y) +saturating_sub(x::UInt32, y::UInt32) = + ccall("llvm.usub.sat.i32", llvmcall, UInt32, (UInt32, UInt32), x, y) +saturating_sub(x::UInt64, y::UInt64) = + ccall("llvm.usub.sat.i64", llvmcall, UInt64, (UInt64, UInt64), x, y) +saturating_sub(x::UInt128, y::UInt128) = + ccall("llvm.usub.sat.i128", llvmcall, UInt128, (UInt128, UInt128), x, y) + function saturating_mul(x::T, y::T) where T <: BitInteger result, overflow_flag = mul_with_overflow(x, y) if overflow_flag @@ -111,7 +133,10 @@ function saturating_pow(x::T, y::S) where {T <: BitInteger, S <: BitInteger} return result end const SignedBitInteger = Union{Int8, Int16, Int32, Int64, Int128} -saturating_abs(x::T) where T <: SignedBitInteger = x == typemin(T) ? typemax(T) : flipsign(x, x) +function saturating_abs(x::T) where T <: SignedBitInteger + result = flipsign(x, x) + return result < 0 ? typemax(T) : result +end if VERSION < v"1.11" # Base.Checked only gained checked powers in 1.11 From c9c46d6c6342cbca2efece9a830d49996c770cf6 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Tue, 7 May 2024 23:53:32 -0400 Subject: [PATCH 09/20] Remove old comment --- src/base_ext.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/base_ext.jl b/src/base_ext.jl index e4267f8..42b24ed 100644 --- a/src/base_ext.jl +++ b/src/base_ext.jl @@ -66,9 +66,6 @@ checked_abs(x) = unchecked_abs(x) # saturating implementations -# widen/clamp reduces to a saturating intrinsic on LLVM for signed integers through 64 bits for +/- -# for unsigned it does not right now (Julia 1.11), but it is still faster than using the with_overflow methods -# But we don't want to widen into a BigInt, so we use the naive approach for Int128 saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x) using Base: llvmcall From 8906ac54fae2d1b43acb37441799ddf0b693d642 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Wed, 8 May 2024 00:10:00 -0400 Subject: [PATCH 10/20] Restore fallbacks for Julia 1 through 1.4 --- src/OverflowContexts.jl | 1 + src/base_ext.jl | 74 +------------------------------- src/base_ext_sat.jl | 95 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 72 deletions(-) create mode 100644 src/base_ext_sat.jl diff --git a/src/OverflowContexts.jl b/src/OverflowContexts.jl index 0c338b0..bd2912d 100644 --- a/src/OverflowContexts.jl +++ b/src/OverflowContexts.jl @@ -2,6 +2,7 @@ module OverflowContexts include("macros.jl") include("base_ext.jl") +include("base_sat.jl") export @default_checked, @default_unchecked, @default_saturating, @checked, @unchecked, @saturating, checked_neg, checked_add, checked_sub, checked_mul, checked_pow, checked_negsub, checked_abs, diff --git a/src/base_ext.jl b/src/base_ext.jl index 42b24ed..29ab61e 100644 --- a/src/base_ext.jl +++ b/src/base_ext.jl @@ -1,9 +1,9 @@ import Base: BitInteger, promote, afoldl, @_inline_meta import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked_abs, - add_with_overflow, sub_with_overflow, mul_with_overflow + mul_with_overflow if VERSION ≥ v"1.11-alpha" - import Base.Checked: checked_pow, pow_with_overflow + import Base.Checked: checked_pow else import Base: throw_domerr_powbysq, to_power_type import Base.Checked: throw_overflowerr_binaryop @@ -65,76 +65,6 @@ checked_pow(x, y) = unchecked_pow(x, y) checked_abs(x) = unchecked_abs(x) -# saturating implementations -saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x) - -using Base: llvmcall -saturating_add(x::Int8, y::Int8) = - ccall("llvm.sadd.sat.i8", llvmcall, Int8, (Int8, Int8), x, y) -saturating_add(x::Int16, y::Int16) = - ccall("llvm.sadd.sat.i16", llvmcall, Int16, (Int16, Int16), x, y) -saturating_add(x::Int32, y::Int32) = - ccall("llvm.sadd.sat.i32", llvmcall, Int32, (Int32, Int32), x, y) -saturating_add(x::Int64, y::Int64) = - ccall("llvm.sadd.sat.i64", llvmcall, Int64, (Int64, Int64), x, y) -saturating_add(x::Int128, y::Int128) = - ccall("llvm.sadd.sat.i128", llvmcall, Int128, (Int128, Int128), x, y) -saturating_add(x::UInt8, y::UInt8) = - ccall("llvm.uadd.sat.i8", llvmcall, UInt8, (UInt8, UInt8), x, y) -saturating_add(x::UInt16, y::UInt16) = - ccall("llvm.uadd.sat.i16", llvmcall, UInt16, (UInt16, UInt16), x, y) -saturating_add(x::UInt32, y::UInt32) = - ccall("llvm.uadd.sat.i32", llvmcall, UInt32, (UInt32, UInt32), x, y) -saturating_add(x::UInt64, y::UInt64) = - ccall("llvm.uadd.sat.i64", llvmcall, UInt64, (UInt64, UInt64), x, y) -saturating_add(x::UInt128, y::UInt128) = - ccall("llvm.uadd.sat.i128", llvmcall, UInt128, (UInt128, UInt128), x, y) - -saturating_sub(x::Int8, y::Int8) = - ccall("llvm.ssub.sat.i8", llvmcall, Int8, (Int8, Int8), x, y) -saturating_sub(x::Int16, y::Int16) = - ccall("llvm.ssub.sat.i16", llvmcall, Int16, (Int16, Int16), x, y) -saturating_sub(x::Int32, y::Int32) = - ccall("llvm.ssub.sat.i32", llvmcall, Int32, (Int32, Int32), x, y) -saturating_sub(x::Int64, y::Int64) = - ccall("llvm.ssub.sat.i64", llvmcall, Int64, (Int64, Int64), x, y) -saturating_sub(x::Int128, y::Int128) = - ccall("llvm.ssub.sat.i128", llvmcall, Int128, (Int128, Int128), x, y) -saturating_sub(x::UInt8, y::UInt8) = - ccall("llvm.usub.sat.i8", llvmcall, UInt8, (UInt8, UInt8), x, y) -saturating_sub(x::UInt16, y::UInt16) = - ccall("llvm.usub.sat.i16", llvmcall, UInt16, (UInt16, UInt16), x, y) -saturating_sub(x::UInt32, y::UInt32) = - ccall("llvm.usub.sat.i32", llvmcall, UInt32, (UInt32, UInt32), x, y) -saturating_sub(x::UInt64, y::UInt64) = - ccall("llvm.usub.sat.i64", llvmcall, UInt64, (UInt64, UInt64), x, y) -saturating_sub(x::UInt128, y::UInt128) = - ccall("llvm.usub.sat.i128", llvmcall, UInt128, (UInt128, UInt128), x, y) - -function saturating_mul(x::T, y::T) where T <: BitInteger - result, overflow_flag = mul_with_overflow(x, y) - if overflow_flag - return sign(x) == sign(y) ? - typemax(T) : - typemin(T) - end - return result -end -function saturating_pow(x::T, y::S) where {T <: BitInteger, S <: BitInteger} - result, overflow_flag = pow_with_overflow(x, y) - if overflow_flag - return sign(x) > 0 ? - typemax(T) : - typemin(T) - end - return result -end -const SignedBitInteger = Union{Int8, Int16, Int32, Int64, Int128} -function saturating_abs(x::T) where T <: SignedBitInteger - result = flipsign(x, x) - return result < 0 ? typemax(T) : result -end - if VERSION < v"1.11" # Base.Checked only gained checked powers in 1.11 diff --git a/src/base_ext_sat.jl b/src/base_ext_sat.jl new file mode 100644 index 0000000..aedfe26 --- /dev/null +++ b/src/base_ext_sat.jl @@ -0,0 +1,95 @@ +# saturating implementations + +saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x) + +using Base: llvmcall +if VERSION ≥ v"1.5" + # These intrinsics were added in LLVM 8, which was first supported with Julia 1.5 + saturating_add(x::Int8, y::Int8) = + ccall("llvm.sadd.sat.i8", llvmcall, Int8, (Int8, Int8), x, y) + saturating_add(x::Int16, y::Int16) = + ccall("llvm.sadd.sat.i16", llvmcall, Int16, (Int16, Int16), x, y) + saturating_add(x::Int32, y::Int32) = + ccall("llvm.sadd.sat.i32", llvmcall, Int32, (Int32, Int32), x, y) + saturating_add(x::Int64, y::Int64) = + ccall("llvm.sadd.sat.i64", llvmcall, Int64, (Int64, Int64), x, y) + saturating_add(x::Int128, y::Int128) = + ccall("llvm.sadd.sat.i128", llvmcall, Int128, (Int128, Int128), x, y) + saturating_add(x::UInt8, y::UInt8) = + ccall("llvm.uadd.sat.i8", llvmcall, UInt8, (UInt8, UInt8), x, y) + saturating_add(x::UInt16, y::UInt16) = + ccall("llvm.uadd.sat.i16", llvmcall, UInt16, (UInt16, UInt16), x, y) + saturating_add(x::UInt32, y::UInt32) = + ccall("llvm.uadd.sat.i32", llvmcall, UInt32, (UInt32, UInt32), x, y) + saturating_add(x::UInt64, y::UInt64) = + ccall("llvm.uadd.sat.i64", llvmcall, UInt64, (UInt64, UInt64), x, y) + saturating_add(x::UInt128, y::UInt128) = + ccall("llvm.uadd.sat.i128", llvmcall, UInt128, (UInt128, UInt128), x, y) + + saturating_sub(x::Int8, y::Int8) = + ccall("llvm.ssub.sat.i8", llvmcall, Int8, (Int8, Int8), x, y) + saturating_sub(x::Int16, y::Int16) = + ccall("llvm.ssub.sat.i16", llvmcall, Int16, (Int16, Int16), x, y) + saturating_sub(x::Int32, y::Int32) = + ccall("llvm.ssub.sat.i32", llvmcall, Int32, (Int32, Int32), x, y) + saturating_sub(x::Int64, y::Int64) = + ccall("llvm.ssub.sat.i64", llvmcall, Int64, (Int64, Int64), x, y) + saturating_sub(x::Int128, y::Int128) = + ccall("llvm.ssub.sat.i128", llvmcall, Int128, (Int128, Int128), x, y) + saturating_sub(x::UInt8, y::UInt8) = + ccall("llvm.usub.sat.i8", llvmcall, UInt8, (UInt8, UInt8), x, y) + saturating_sub(x::UInt16, y::UInt16) = + ccall("llvm.usub.sat.i16", llvmcall, UInt16, (UInt16, UInt16), x, y) + saturating_sub(x::UInt32, y::UInt32) = + ccall("llvm.usub.sat.i32", llvmcall, UInt32, (UInt32, UInt32), x, y) + saturating_sub(x::UInt64, y::UInt64) = + ccall("llvm.usub.sat.i64", llvmcall, UInt64, (UInt64, UInt64), x, y) + saturating_sub(x::UInt128, y::UInt128) = + ccall("llvm.usub.sat.i128", llvmcall, UInt128, (UInt128, UInt128), x, y) +else + import Base.Checked: add_with_overflow, sub_with_overflow, mul_with_overflow + + function saturating_add(x::T, y::T) where T <: Union{Int128, UInt128} + result, overflow_flag = add_with_overflow(x, y) + if overflow_flag + return sign(x) > 0 ? + typemax(T) : + typemin(T) + end + return result + end + + function saturating_sub(x::T, y::T) where T <: Union{Int128, UInt128} + result, overflow_flag = sub_with_overflow(x, y) + if overflow_flag + return y > x ? + typemin(T) : + typemax(T) + end + return result + end +end + +function saturating_mul(x::T, y::T) where T <: BitInteger + result, overflow_flag = mul_with_overflow(x, y) + if overflow_flag + return sign(x) == sign(y) ? + typemax(T) : + typemin(T) + end + return result +end +function saturating_pow(x::T, y::S) where {T <: BitInteger, S <: BitInteger} + result, overflow_flag = pow_with_overflow(x, y) + if overflow_flag + return sign(x) > 0 ? + typemax(T) : + typemin(T) + end + return result +end +const SignedBitInteger = Union{Int8, Int16, Int32, Int64, Int128} +function saturating_abs(x::T) where T <: SignedBitInteger + result = flipsign(x, x) + return result < 0 ? typemax(T) : result +end From 0da4c61567b1f98923a5864ce45e8aec88e5f881 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Wed, 8 May 2024 00:19:38 -0400 Subject: [PATCH 11/20] Fix include --- src/OverflowContexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/OverflowContexts.jl b/src/OverflowContexts.jl index bd2912d..afcb9a9 100644 --- a/src/OverflowContexts.jl +++ b/src/OverflowContexts.jl @@ -2,7 +2,7 @@ module OverflowContexts include("macros.jl") include("base_ext.jl") -include("base_sat.jl") +include("base_ext_sat.jl") export @default_checked, @default_unchecked, @default_saturating, @checked, @unchecked, @saturating, checked_neg, checked_add, checked_sub, checked_mul, checked_pow, checked_negsub, checked_abs, From 1b448906b4fbd96d78170f20018a860532c6dbef Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Wed, 8 May 2024 00:20:14 -0400 Subject: [PATCH 12/20] Correct backport of Julia 1.11's checked_pow implementation --- src/base_ext.jl | 37 ++++++++++++------------------------- src/base_ext_sat.jl | 6 +++--- 2 files changed, 15 insertions(+), 28 deletions(-) diff --git a/src/base_ext.jl b/src/base_ext.jl index 29ab61e..20af79d 100644 --- a/src/base_ext.jl +++ b/src/base_ext.jl @@ -68,50 +68,37 @@ checked_abs(x) = unchecked_abs(x) if VERSION < v"1.11" # Base.Checked only gained checked powers in 1.11 -function checked_pow(x::T, y::S) where {T <: BitInteger, S <: BitInteger} - @_inline_meta - z, b = pow_with_overflow(x, y) - b && throw_overflowerr_binaryop(:^, x, y) - z -end +checked_pow(x_::T, p::S) where {T <: BitInteger, S <: BitInteger} = + power_by_squaring(x_, p; mul = checked_mul) -function pow_with_overflow(x_, p::Integer) +Base.@assume_effects :terminates_locally function power_by_squaring(x_, p::Integer; mul=*) x = to_power_type(x_) if p == 1 - return (copy(x), false) + return copy(x) elseif p == 0 - return (one(x), false) + return one(x) elseif p == 2 - return mul_with_overflow(x, x) + return mul(x, x) elseif p < 0 - isone(x) && return (copy(x), false) - isone(-x) && return (iseven(p) ? one(x) : copy(x), false) + isone(x) && return copy(x) + isone(-x) && return iseven(p) ? one(x) : copy(x) throw_domerr_powbysq(x, p) end t = trailing_zeros(p) + 1 p >>= t - b = false while (t -= 1) > 0 - x, b1 = mul_with_overflow(x, x) - b |= b1 + x = mul(x, x) end y = x while p > 0 t = trailing_zeros(p) + 1 p >>= t while (t -= 1) >= 0 - x, b1 = mul_with_overflow(x, x) - b |= b1 + x = mul(x, x) end - y, b1 = mul_with_overflow(y, x) - b |= b1 + y = mul(y, x) end - return y, b -end -pow_with_overflow(x::Bool, p::Unsigned) = ((p==0) | x, false) -function pow_with_overflow(x::Bool, p::Integer) - p < 0 && !x && throw_domerr_powbysq(x, p) - return (p==0) | x, false + return y end end diff --git a/src/base_ext_sat.jl b/src/base_ext_sat.jl index aedfe26..b926d75 100644 --- a/src/base_ext_sat.jl +++ b/src/base_ext_sat.jl @@ -79,10 +79,10 @@ function saturating_mul(x::T, y::T) where T <: BitInteger end return result end -function saturating_pow(x::T, y::S) where {T <: BitInteger, S <: BitInteger} - result, overflow_flag = pow_with_overflow(x, y) +function saturating_pow(x_::T, p::S) where {T <: BitInteger, S <: BitInteger} + result, overflow_flag = power_by_squaring(x_, p; mul = saturating_mul) if overflow_flag - return sign(x) > 0 ? + return sign(x_) > 0 ? typemax(T) : typemin(T) end From 7d871665563c227297af35222578673ce0111e61 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Wed, 8 May 2024 00:24:57 -0400 Subject: [PATCH 13/20] assume_effects not supported before 1.8 --- src/base_ext.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/base_ext.jl b/src/base_ext.jl index 20af79d..4e3c3ea 100644 --- a/src/base_ext.jl +++ b/src/base_ext.jl @@ -71,7 +71,8 @@ if VERSION < v"1.11" checked_pow(x_::T, p::S) where {T <: BitInteger, S <: BitInteger} = power_by_squaring(x_, p; mul = checked_mul) -Base.@assume_effects :terminates_locally function power_by_squaring(x_, p::Integer; mul=*) +# Base.@assume_effects :terminates_locally # present in Julia 1.11 code, but only supported from 1.8 on +function power_by_squaring(x_, p::Integer; mul=*) x = to_power_type(x_) if p == 1 return copy(x) From e7a9a3d77f4cbfc45d68ef1d73fc6a8c9dc1c0e6 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Wed, 8 May 2024 00:25:07 -0400 Subject: [PATCH 14/20] tidier --- src/base_ext_sat.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/base_ext_sat.jl b/src/base_ext_sat.jl index b926d75..ad493ca 100644 --- a/src/base_ext_sat.jl +++ b/src/base_ext_sat.jl @@ -2,8 +2,9 @@ saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x) -using Base: llvmcall if VERSION ≥ v"1.5" + using Base: llvmcall + # These intrinsics were added in LLVM 8, which was first supported with Julia 1.5 saturating_add(x::Int8, y::Int8) = ccall("llvm.sadd.sat.i8", llvmcall, Int8, (Int8, Int8), x, y) From 47189afdddad7b1746e2fb797b8c7b658b06ede2 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Wed, 8 May 2024 00:25:48 -0400 Subject: [PATCH 15/20] Fix saturating_pow --- src/base_ext_sat.jl | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/base_ext_sat.jl b/src/base_ext_sat.jl index ad493ca..27f1985 100644 --- a/src/base_ext_sat.jl +++ b/src/base_ext_sat.jl @@ -80,15 +80,8 @@ function saturating_mul(x::T, y::T) where T <: BitInteger end return result end -function saturating_pow(x_::T, p::S) where {T <: BitInteger, S <: BitInteger} - result, overflow_flag = power_by_squaring(x_, p; mul = saturating_mul) - if overflow_flag - return sign(x_) > 0 ? - typemax(T) : - typemin(T) - end - return result -end +saturating_pow(x_::T, p::S) where {T <: BitInteger, S <: BitInteger} = + power_by_squaring(x_, p; mul = saturating_mul) const SignedBitInteger = Union{Int8, Int16, Int32, Int64, Int128} function saturating_abs(x::T) where T <: SignedBitInteger result = flipsign(x, x) From 722606103e1a0c393e49212270fdcaf2015c6e05 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Wed, 8 May 2024 00:32:53 -0400 Subject: [PATCH 16/20] Fix bad sat test --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 84afdeb..d34bfdb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,7 +60,7 @@ end @test @saturating(typemax(UInt) * 2) == typemax(UInt) @test @saturating(typemax(Int) ^ 2) == typemax(Int) - @test @saturating(typemin(Int) ^ 2) == typemin(Int) + @test @saturating(typemin(Int) ^ 2) == typemax(Int) @test @saturating(typemax(UInt) ^ 2) == typemax(UInt) @test @saturating(abs(typemin(Int))) == typemax(Int) From 78781444bd77b8999ebb87ec7c19dbbce13a5ba8 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Wed, 8 May 2024 00:33:24 -0400 Subject: [PATCH 17/20] Widen types for old-version implementations --- src/base_ext_sat.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/base_ext_sat.jl b/src/base_ext_sat.jl index 27f1985..4b1e615 100644 --- a/src/base_ext_sat.jl +++ b/src/base_ext_sat.jl @@ -50,7 +50,7 @@ if VERSION ≥ v"1.5" else import Base.Checked: add_with_overflow, sub_with_overflow, mul_with_overflow - function saturating_add(x::T, y::T) where T <: Union{Int128, UInt128} + function saturating_add(x::T, y::T) where T <: BitInteger result, overflow_flag = add_with_overflow(x, y) if overflow_flag return sign(x) > 0 ? @@ -60,7 +60,7 @@ else return result end - function saturating_sub(x::T, y::T) where T <: Union{Int128, UInt128} + function saturating_sub(x::T, y::T) where T <: BitInteger result, overflow_flag = sub_with_overflow(x, y) if overflow_flag return y > x ? From 19b4e80c389f89bcd17b454f8769d130ef85454d Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Wed, 8 May 2024 00:38:46 -0400 Subject: [PATCH 18/20] fix imports --- src/base_ext.jl | 1 + src/base_ext_sat.jl | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/base_ext.jl b/src/base_ext.jl index 4e3c3ea..2b3d3c2 100644 --- a/src/base_ext.jl +++ b/src/base_ext.jl @@ -3,6 +3,7 @@ import Base.Checked: checked_neg, checked_add, checked_sub, checked_mul, checked mul_with_overflow if VERSION ≥ v"1.11-alpha" + import Base: power_by_squaring import Base.Checked: checked_pow else import Base: throw_domerr_powbysq, to_power_type diff --git a/src/base_ext_sat.jl b/src/base_ext_sat.jl index 4b1e615..e9929c9 100644 --- a/src/base_ext_sat.jl +++ b/src/base_ext_sat.jl @@ -1,3 +1,10 @@ +import Base: BitInteger +import Base.Checked: mul_with_overflow + +if VERSION ≤ v"1.11-alpha" + import Base: power_by_squaring +end + # saturating implementations saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x) @@ -48,7 +55,7 @@ if VERSION ≥ v"1.5" saturating_sub(x::UInt128, y::UInt128) = ccall("llvm.usub.sat.i128", llvmcall, UInt128, (UInt128, UInt128), x, y) else - import Base.Checked: add_with_overflow, sub_with_overflow, mul_with_overflow + import Base.Checked: add_with_overflow, sub_with_overflow function saturating_add(x::T, y::T) where T <: BitInteger result, overflow_flag = add_with_overflow(x, y) From 8c5de8c2d153671568d992a8223571f895eb5be2 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Fri, 10 May 2024 17:53:38 -0400 Subject: [PATCH 19/20] Switch to generated functions --- src/base_ext_sat.jl | 53 +++++++++++---------------------------------- 1 file changed, 13 insertions(+), 40 deletions(-) diff --git a/src/base_ext_sat.jl b/src/base_ext_sat.jl index e9929c9..ff84a19 100644 --- a/src/base_ext_sat.jl +++ b/src/base_ext_sat.jl @@ -13,47 +13,20 @@ if VERSION ≥ v"1.5" using Base: llvmcall # These intrinsics were added in LLVM 8, which was first supported with Julia 1.5 - saturating_add(x::Int8, y::Int8) = - ccall("llvm.sadd.sat.i8", llvmcall, Int8, (Int8, Int8), x, y) - saturating_add(x::Int16, y::Int16) = - ccall("llvm.sadd.sat.i16", llvmcall, Int16, (Int16, Int16), x, y) - saturating_add(x::Int32, y::Int32) = - ccall("llvm.sadd.sat.i32", llvmcall, Int32, (Int32, Int32), x, y) - saturating_add(x::Int64, y::Int64) = - ccall("llvm.sadd.sat.i64", llvmcall, Int64, (Int64, Int64), x, y) - saturating_add(x::Int128, y::Int128) = - ccall("llvm.sadd.sat.i128", llvmcall, Int128, (Int128, Int128), x, y) - saturating_add(x::UInt8, y::UInt8) = - ccall("llvm.uadd.sat.i8", llvmcall, UInt8, (UInt8, UInt8), x, y) - saturating_add(x::UInt16, y::UInt16) = - ccall("llvm.uadd.sat.i16", llvmcall, UInt16, (UInt16, UInt16), x, y) - saturating_add(x::UInt32, y::UInt32) = - ccall("llvm.uadd.sat.i32", llvmcall, UInt32, (UInt32, UInt32), x, y) - saturating_add(x::UInt64, y::UInt64) = - ccall("llvm.uadd.sat.i64", llvmcall, UInt64, (UInt64, UInt64), x, y) - saturating_add(x::UInt128, y::UInt128) = - ccall("llvm.uadd.sat.i128", llvmcall, UInt128, (UInt128, UInt128), x, y) + @generated function saturating_add(x::T, y::T) where T <: BitInteger + llvm_su = T <: Signed ? "s" : "u" + llvm_t = "i" * string(8sizeof(T)) + llvm_intrinsic = "llvm.$(llvm_su)add.sat.$llvm_t" + :(ccall($llvm_intrinsic, llvmcall, $T, ($T, $T), x, y)) + end + + @generated function saturating_sub(x::T, y::T) where T <: BitInteger + llvm_su = T <: Signed ? "s" : "u" + llvm_t = "i" * string(8sizeof(T)) + llvm_intrinsic = "llvm.$(llvm_su)sub.sat.$llvm_t" + :(ccall($llvm_intrinsic, llvmcall, $T, ($T, $T), x, y)) + end - saturating_sub(x::Int8, y::Int8) = - ccall("llvm.ssub.sat.i8", llvmcall, Int8, (Int8, Int8), x, y) - saturating_sub(x::Int16, y::Int16) = - ccall("llvm.ssub.sat.i16", llvmcall, Int16, (Int16, Int16), x, y) - saturating_sub(x::Int32, y::Int32) = - ccall("llvm.ssub.sat.i32", llvmcall, Int32, (Int32, Int32), x, y) - saturating_sub(x::Int64, y::Int64) = - ccall("llvm.ssub.sat.i64", llvmcall, Int64, (Int64, Int64), x, y) - saturating_sub(x::Int128, y::Int128) = - ccall("llvm.ssub.sat.i128", llvmcall, Int128, (Int128, Int128), x, y) - saturating_sub(x::UInt8, y::UInt8) = - ccall("llvm.usub.sat.i8", llvmcall, UInt8, (UInt8, UInt8), x, y) - saturating_sub(x::UInt16, y::UInt16) = - ccall("llvm.usub.sat.i16", llvmcall, UInt16, (UInt16, UInt16), x, y) - saturating_sub(x::UInt32, y::UInt32) = - ccall("llvm.usub.sat.i32", llvmcall, UInt32, (UInt32, UInt32), x, y) - saturating_sub(x::UInt64, y::UInt64) = - ccall("llvm.usub.sat.i64", llvmcall, UInt64, (UInt64, UInt64), x, y) - saturating_sub(x::UInt128, y::UInt128) = - ccall("llvm.usub.sat.i128", llvmcall, UInt128, (UInt128, UInt128), x, y) else import Base.Checked: add_with_overflow, sub_with_overflow From 311cefdf58e180261afe420fff2fe695968adc84 Mon Sep 17 00:00:00 2001 From: Nicholas Bauer Date: Thu, 23 May 2024 11:31:50 -0400 Subject: [PATCH 20/20] style adjustment --- src/base_ext_sat.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/base_ext_sat.jl b/src/base_ext_sat.jl index ff84a19..102bcd2 100644 --- a/src/base_ext_sat.jl +++ b/src/base_ext_sat.jl @@ -6,6 +6,7 @@ if VERSION ≤ v"1.11-alpha" end # saturating implementations +const SignedBitInteger = Union{Int8, Int16, Int32, Int64, Int128} saturating_neg(x::T) where T <: BitInteger = saturating_sub(zero(T), x) @@ -53,16 +54,16 @@ end function saturating_mul(x::T, y::T) where T <: BitInteger result, overflow_flag = mul_with_overflow(x, y) - if overflow_flag - return sign(x) == sign(y) ? + return overflow_flag ? + (sign(x) == sign(y) ? typemax(T) : - typemin(T) - end - return result + typemin(T)) : + result end + saturating_pow(x_::T, p::S) where {T <: BitInteger, S <: BitInteger} = power_by_squaring(x_, p; mul = saturating_mul) -const SignedBitInteger = Union{Int8, Int16, Int32, Int64, Int128} + function saturating_abs(x::T) where T <: SignedBitInteger result = flipsign(x, x) return result < 0 ? typemax(T) : result