diff --git a/src/hasbranching.jl b/src/hasbranching.jl index 073b70005..24914e601 100644 --- a/src/hasbranching.jl +++ b/src/hasbranching.jl @@ -10,8 +10,10 @@ function Cassette.overdub(ctx::HasBranchingCtx, f, args...) end end -for (mod, f, n) in DiffRules.diffrules() - isdefined(@__MODULE__, mod) || continue +for (mod, f, n) in DiffRules.diffrules(; filter_modules=nothing) + if !(isdefined(@__MODULE__, mod) && isdefined(getfield(@__MODULE__, mod), f)) + continue # Skip rules for methods not defined in the current scope + end @eval function Cassette.overdub(::HasBranchingCtx, f::Core.Typeof($mod.$f), x::Vararg{Any, $n}) f(x...)