From 4aaa97ff729a4f3b4027e48cfdd8b55640cda4cc Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 19 Jul 2021 22:02:27 +0100 Subject: [PATCH] Don't return opted out rules --- Project.toml | 4 ++-- src/ruleset_loading.jl | 27 +++++++++++++++++++-------- test/ruleset_loading.jl | 33 +++++++++++++++++++++++++++------ test/runtests.jl | 6 ++++-- 4 files changed, 52 insertions(+), 18 deletions(-) diff --git a/Project.toml b/Project.toml index edde3ec..7f9ffbc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,12 @@ name = "ChainRulesOverloadGeneration" uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f" -version = "0.1.3" +version = "0.1.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" [compat] -ChainRulesCore = "0.10.4" +ChainRulesCore = "1.0.0" julia = "1" [extras] diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl index b57cab3..b851399 100644 --- a/src/ruleset_loading.jl +++ b/src/ruleset_loading.jl @@ -48,27 +48,38 @@ If you previously wrong an incorrect hook, you can use this to get rid of the ol """ clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind)) +########################################################################################### + """ _rule_list(frule | rrule) Returns a list of all the methods of the currently defined rules of the given kind. -Excluding the fallback rule that returns `nothing` for every input; -and excluding rules that require a particular `RuleConfig`. +Excluding the fallback rule (that return `nothing` for every input) and `@opt_out` opted out +rules, and excluding rules that require a particular `RuleConfig`. """ -function _rule_list(rule_kind) +function _rule_list(rule_kind::Union{typeof(frule), typeof(rrule)}) + opted_out = Set(arg_type_tuple(m.sig) for m in _no_rule_list(rule_kind)) return Iterators.filter(methods(rule_kind)) do m - return !_is_fallback(rule_kind, m) && !_requires_config(m) + return !_requires_config(m) && arg_type_tuple(m.sig) ∉ opted_out end end -"check if this is the fallback-frule/rrule that always returns `nothing`" -_is_fallback(::typeof(rrule), m::Method) = m.sig === Tuple{typeof(rrule),Any,Vararg{Any}} -_is_fallback(::typeof(frule), m::Method) = m.sig === Tuple{typeof(frule),Any,Any,Vararg{Any}} - "check if this rule requires a particular configuation (`RuleConfig`)" _requires_config(m::Method) = m.sig <: Tuple{Any, RuleConfig, Vararg} +_no_rule_list(::typeof(rrule)) = methods(ChainRulesCore.no_rrule) +_no_rule_list(::typeof(frule)) = methods(ChainRulesCore.no_frule) + +arg_type_tuple(d::DataType) = Tuple{d.parameters[2:end]...} +function arg_type_tuple(d::UnionAll) + body = Base.unwrap_unionall(d) + body_tt = arg_type_tuple(body) + return Base.rewrap_unionall(body_tt, d) +end + +###################################################################### + const LAST_REFRESH_RRULE = Ref(0) const LAST_REFRESH_FRULE = Ref(0) last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl index 4e261a3..368e872 100644 --- a/test/ruleset_loading.jl +++ b/test/ruleset_loading.jl @@ -70,12 +70,6 @@ end end - @testset "_is_fallback" begin - _is_fallback = ChainRulesOverloadGeneration._is_fallback - @test _is_fallback(rrule, first(methods(rrule, (Nothing,)))) - @test _is_fallback(frule, first(methods(frule, (Tuple{}, Nothing,)))) - end - @testset "_rule_list" begin _rule_list = ChainRulesOverloadGeneration._rule_list @testset "should not have frules that need RuleConfig" begin @@ -112,5 +106,32 @@ # Above would error if we were not handling UnionAll's right end end + + + @testset "opting out" begin + oa_id(x, y) = x + @scalar_rule(oa_id(x::Number), 1) + @opt_out ChainRulesCore.rrule(::typeof(oa_id), x::Float32) + @opt_out ChainRulesCore.frule(::Any, ::typeof(oa_id), x::Float32) + + # In theses tests we `@assert` the behavour that `methods` has + # and then `@test` that `_rule_list` differs from that, in the way we want + + @test !isempty([m for m in _rule_list(rrule) if m.sig <: Tuple{Any,typeof(oa_id),Number}]) + # Opted out + @assert !isempty([m for m in methods(rrule) if m.sig <: Tuple{Any,typeof(oa_id),Float32}]) + @test isempty([m for m in _rule_list(rrule) if m.sig <: Tuple{Any,typeof(oa_id),Float32}]) + # fallback + @test !isempty([m for m in methods(rrule) if m.sig == Tuple{typeof(rrule),Any,Vararg{Any}}]) + @test isempty([m for m in _rule_list(rrule) if m.sig == Tuple{typeof(rrule),Any,Vararg{Any}}]) + + @test !isempty([m for m in _rule_list(frule) if m.sig <: Tuple{Any,Any,typeof(oa_id),Number}]) + # Opted out + @assert !isempty([m for m in methods(frule) if m.sig <: Tuple{Any,Any,typeof(oa_id),Float32}]) + @test isempty([m for m in _rule_list(frule) if m.sig <: Tuple{Any,Any,typeof(oa_id),Float32}]) + # fallback + @assert !isempty([m for m in methods(frule) if m.sig == Tuple{typeof(frule),Any,Any,Vararg{Any}}]) + @test isempty([m for m in _rule_list(frule) if m.sig == Tuple{typeof(frule),Any,Any,Vararg{Any}}]) + end end end diff --git a/test/runtests.jl b/test/runtests.jl index 8795d8f..ab0db53 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,10 +4,12 @@ using ChainRulesOverloadGeneration using Test @testset "ChainRulesCore" begin - include("ruleset_loading.jl") - @testset "demos" begin include("demos/forwarddiffzero.jl") include("demos/reversediffzero.jl") end + + # Do this after demos run, so that the simple demo code doesn't have to handle + # anything weird we define for testing purposes + include("ruleset_loading.jl") end