From a6e916e05c5e4a9b9ffe5b4dd5819a1726b96da1 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 8 Jan 2025 21:53:15 -0100 Subject: [PATCH 1/3] Specialize NaNMath.pow to allow for ::Int ^ ::Int -> ::Int NaNMath does not handle integer dispatches. This fixes it so that a wrapper function handles the integer dispatch, restoring the behavior of ^ --- Project.toml | 2 +- src/SymbolicUtils.jl | 6 ++++++ src/code.jl | 4 ++-- src/methods.jl | 2 +- test/fuzzlib.jl | 2 +- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index f53d33729..7a5bbf39c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicUtils" uuid = "d1185830-fcd6-423d-90d6-eec64667417b" authors = ["Shashi Gowda"] -version = "3.8.1" +version = "3.8.2" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index fb13f50b4..3e707228c 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -30,6 +30,12 @@ include("types.jl") # Methods on symbolic objects using SpecialFunctions, NaNMath + +# NaNMath.pow does not handle x::Int ^ y::Int -> ::Int +# Use this instead as a wrapper over NaNMath.pow +pow(x,y) = NaNMath.pow(x,y) +pow(x::Int, y::Int) = x^y + import IfElse: ifelse # need to not bring IfElse name in or it will clash with Rewriters.IfElse include("methods.jl") diff --git a/src/code.jl b/src/code.jl index 8b1953b20..d7f3576d5 100644 --- a/src/code.jl +++ b/src/code.jl @@ -146,12 +146,12 @@ function function_to_expr(op::typeof(^), O, st) return toexpr(Term(inv, Any[ex]), st) else args = Any[Term(inv, Any[ex]), -args[2]] - op = get(st.rewrites, :nanmath, false) ? op : NaNMath.pow + op = get(st.rewrites, :nanmath, false) ? op : pow return toexpr(Term(op, args), st) end end get(st.rewrites, :nanmath, false) === true || return nothing - return toexpr(Term(NaNMath.pow, args), st) + return toexpr(Term(pow, args), st) end function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st) diff --git a/src/methods.jl b/src/methods.jl index 2baef6424..665e65321 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -20,7 +20,7 @@ const monadic = [deg2rad, rad2deg, transpose, asind, log1p, acsch, const diadic = [max, min, hypot, atan, NaNMath.atanh, mod, rem, copysign, besselj, bessely, besseli, besselk, hankelh1, hankelh2, - polygamma, beta, logbeta, NaNMath.pow] + polygamma, beta, logbeta, pow] const previously_declared_for = Set([]) const basic_monadic = [-, +] diff --git a/test/fuzzlib.jl b/test/fuzzlib.jl index 163467d6a..7a495a794 100644 --- a/test/fuzzlib.jl +++ b/test/fuzzlib.jl @@ -42,7 +42,7 @@ const num_spec = let ()->rand([a b c d e f])] binops = SymbolicUtils.diadic - nopow = setdiff(binops, [(^), NaNMath.pow, besselj0, besselj1, bessely0, bessely1, besselj, bessely, besseli, besselk]) + nopow = setdiff(binops, [(^), SymbolicUtils.pow, besselj0, besselj1, bessely0, bessely1, besselj, bessely, besseli, besselk]) twoargfns = vcat(nopow, (x,y)->x isa Union{Int, Rational, Complex{<:Rational}} ? x * y : x^y) fns = vcat(1 .=> vcat(SymbolicUtils.monadic, [one, zero]), 2 .=> vcat(twoargfns, fill(+, 5), [-,-], fill(*, 5), fill(/, 40)), From 48c7e0b07da55359f6143359035a196e1afe505b Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 8 Jan 2025 22:11:30 -0100 Subject: [PATCH 2/3] Update src/code.jl Co-authored-by: David Widmann --- src/code.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/code.jl b/src/code.jl index d7f3576d5..7d9842938 100644 --- a/src/code.jl +++ b/src/code.jl @@ -146,7 +146,7 @@ function function_to_expr(op::typeof(^), O, st) return toexpr(Term(inv, Any[ex]), st) else args = Any[Term(inv, Any[ex]), -args[2]] - op = get(st.rewrites, :nanmath, false) ? op : pow + op = get(st.rewrites, :nanmath, false) === true ? pow : op return toexpr(Term(op, args), st) end end From 6a608e0ea7670560cdd9c80d604e2e4300ee47e6 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Wed, 8 Jan 2025 22:25:49 -0100 Subject: [PATCH 3/3] import into code module --- src/code.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/code.jl b/src/code.jl index d7f3576d5..39692171f 100644 --- a/src/code.jl +++ b/src/code.jl @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, - symtype, sorted_arguments, metadata, isterm, term, maketerm + symtype, sorted_arguments, metadata, isterm, term, maketerm, pow import SymbolicIndexingInterface: symbolic_type, NotSymbolic ##== state management ==##