From 23ec91dd25ed1870488fb682aa660b8281bc081d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 19 Jul 2021 12:22:59 +0100 Subject: [PATCH 1/6] minor cleanup of rule_definition_tools.jl --- src/ChainRulesCore.jl | 1 + src/rule_definition_tools.jl | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 1114bd6fd..fe54e666f 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -1,5 +1,6 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! +using Base.Meta using LinearAlgebra using SparseArrays: SparseVector, SparseMatrixCSC using Compat: hasfield diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 830b5bba8..c709c89c0 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -1,9 +1,7 @@ # These are some macros (and supporting functions) to make it easier to define rules. -using Base.Meta -macro strip_linenos(expr) - return esc(Base.remove_linenums!(expr)) -end +############################################################################################ +### @scalar_rule """ @scalar_rule(f(x₁, x₂, ...), @@ -88,7 +86,6 @@ macro scalar_rule(call, maybe_setup, partials...) 
frule_expr = scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) rrule_expr = scalar_rrule_expr(__source__, f, call, setup_stmts, inputs, partials) - ############################################################################ # Final return: building the expression to insert in the place of this macro code = quote if !($f isa Type) && fieldcount(typeof($f)) > 0 @@ -114,7 +111,6 @@ returns (in order) the correctly escaped: - `partials`: which are all `Expr{:tuple,...}` """ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) - ############################################################################ # Setup: normalizing input form etc if Meta.isexpr(maybe_setup, :macrocall) && maybe_setup.args[1] == Symbol("@setup") @@ -275,6 +271,9 @@ propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propna propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname) propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname) +############################################################################################ +### @non_differentiable + """ @non_differentiable(signature_expression) @@ -324,7 +323,7 @@ macro non_differentiable(sig_expr) :($(primal_name)($(unconstrained_args...))) else normal_args = unconstrained_args[1:end-1] - var_arg = unconstrained_args[end] + var_arg = s[end] :($(primal_name)($(normal_args...), $(var_arg)...)) end @@ -393,10 +392,13 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) end end - -########### +############################################################################################ # Helpers +macro strip_linenos(expr) + return esc(Base.remove_linenums!(expr)) +end + """ _isvararg(expr) From 8fa0ecba27b6a9c04e7e3628d8f474b03e93d09e Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 19 Jul 2021 15:54:22 +0100 Subject: [PATCH 2/6] opting out of rules --- docs/make.jl | 1 + docs/src/api.md | 2 + 
docs/src/opting_out_of_rules.md | 92 +++++++++++++++++++++++++++++++++ src/ChainRulesCore.jl | 2 +- src/rule_definition_tools.jl | 73 ++++++++++++++++++++++++-- src/rules.jl | 49 ++++++++++++++++++ test/rules.jl | 28 ++++++++++ 7 files changed, 242 insertions(+), 5 deletions(-) create mode 100644 docs/src/opting_out_of_rules.md diff --git a/docs/make.jl b/docs/make.jl index 1b2063ac4..71c4468c1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -48,6 +48,7 @@ makedocs( "Introduction" => "index.md", "FAQ" => "FAQ.md", "Rule configurations and calling back into AD" => "config.md", + "Opting out of rules" => "opting_out_of_rules.md", "Writing Good Rules" => "writing_good_rules.md", "Complex Numbers" => "complex.md", "Deriving Array Rules" => "arrays.md", diff --git a/docs/src/api.md b/docs/src/api.md index 3b48151aa..491749f04 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -50,4 +50,6 @@ ProjectTo ```@docs ChainRulesCore.AbstractTangent ChainRulesCore.debug_mode +ChainRulesCore.no_rrule +ChainRulesCore.no_frule ``` \ No newline at end of file diff --git a/docs/src/opting_out_of_rules.md b/docs/src/opting_out_of_rules.md new file mode 100644 index 000000000..c1f7e5894 --- /dev/null +++ b/docs/src/opting_out_of_rules.md @@ -0,0 +1,92 @@ +# Opting out of rules + +It is common to define rules fairly generically. +Often matching (or exceeding) how generic the matching original primal method is. +Sometimes this is not the correct behavour. +Sometimes the AD can do better than this human defined rule. +If this is generally the case, then we should not have the rule defined at all. +But if it is only the case for a particular set of types, then we want to opt-out just that one. +This is done with the [`@opt_out`](@ref) macro. 
+ +Consider one might have a rrule for `sum` (the following simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself) +```julia +function rrule(::typeof(sum), x::AbstractArray{<:Number}; dims=:) + y = sum(x; dims=dims) + project = ProjectTo(x) + function sum_pullback(ȳ) + # broadcasting the two works out the size no-matter `dims` + # project makes sure we stay in the same vector subspace as `x` + # no putting in off-diagonal entries in Diagonal etc + x̄ = project(broadcast(last∘tuple, x, ȳ)) + return (NoTangent(), x̄) + end + return y, sum_pullback +end +``` + +That is a fairly reasonable `rrule` for the vast majority of cases. + +You might have a custom array type for which you could write a faster rule. +For example, the pullback for summing a`SkewSymmetric` matrix can be optimizes to basically be `Diagonal(fill(ȳ, size(x,1)))`. +To do that, you can indeed write another more specific [`rrule`](@ref). +But another case is where the AD system itself would generate a more optimized case. + +For example, the a [`NamedDimArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type. +It's sum method is basically just to call `sum` on it's parent. +It is entirely conceivable[^1] that the AD system can do better than our `rrule` here. +For example by avoiding the overhead of [`project`ing](@ref ProjectTo). + +To opt-out of using the `rrule` and to allow the AD system to do its own thing we use the +[`@opt_out`](@ref) macro, to say to not use it for sum. + +```julia +@opt_out rrule(::typeof(sum), ::NamedDimsArray) +``` + +We could even opt-out for all 1 arg functions. +```julia +@opt_out rrule(::Any, ::NamedDimsArray) +``` +Though this is likely to cause some method-ambiguities. + +Similar can be done `@opt_out frule`. +It can also be done passing in a [`RuleConfig`](@ref config). 
+ + +### How to support this (for AD implementers) + +We provide two ways to know that a rule has been opted out of. + +## `rrule` / `frule` returns `nothing` + +`@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`. + +If you are in a position to generate code, in response to values returned by function calls then you can do something like: +```julia +res = rrule(f, xs) +if res === nothing + y, pullback = perform_ad_via_decomposition(f, xs) # do AD without hitting the rrule +else + y, pullback = res +end +``` +The Julia compiler, will specialize based on inferring the restun type of `rrule`, and so can remove that branch. + +## `no_rrule` / `no_frule` has a method + +`@opt_out` also defines a method for [`ChainRulesCore.no_frule`](@ref) or [`ChainRulesCore.no_rrule`](@ref). +The use of this method doesn't matter, what matters is it's method-table. +A simple thing you can do with this is not support opting out. +To do this, filter all methods from the `rrule`/`frule` method table that also occur in the `no_frule`/`no_rrule` table. +This will thus avoid ever hitting an `rrule`/`frule` that returns `nothing` and thus makes your library error. +This is easily done, though it does mean ignoring the user's stated desire to opt out of the rule. + +More complex you can use this to generate code that triggers your AD. +If for a given signature there is a more specific method in the `no_rrule`/`no_frule` method-table, than the one that would be hit from the `rrule`/`frule` table +(Excluding the one that exactly matches which will return `nothing`) then you know that the rule should not be used. +You can, likely by looking at the primal method table, work out which method you would have hit if the rule had not been defined, +and then `invoke` it. + + + +[^1]: It is also possible, that this is not the case. Benchmark your real uses cases. 
\ No newline at end of file diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index fe54e666f..09c3ebc4c 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -10,7 +10,7 @@ export frule, rrule # core function export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode export frule_via_ad, rrule_via_ad # definition helper macros -export @non_differentiable, @scalar_rule, @thunk, @not_implemented +export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented export ProjectTo, canonicalize, unthunk # differential operations export add!! # gradient accumulation operations # differentials diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index c709c89c0..cdf424d05 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -1,5 +1,10 @@ # These are some macros (and supporting functions) to make it easier to define rules. +# Note: must be declared before it is used, which is later in this file. +macro strip_linenos(expr) + return esc(Base.remove_linenums!(expr)) +end + ############################################################################################ ### @scalar_rule @@ -323,7 +328,7 @@ macro non_differentiable(sig_expr) :($(primal_name)($(unconstrained_args...))) else normal_args = unconstrained_args[1:end-1] - var_arg = s[end] + var_arg = unconstrained_args[end] :($(primal_name)($(normal_args...), $(var_arg)...)) end @@ -392,13 +397,73 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) end end + ############################################################################################ -# Helpers +# @opt_out -macro strip_linenos(expr) - return esc(Base.remove_linenums!(expr)) +""" + @opt_out frule([config], _, f, args...) + @opt_out rrule([config], f, args...) + +This allows you to opt-out of a `frule` or `rrule` by providing a more specific method, +that says to use the AD system, to solver it. 
+ +For example, consider some function `foo(x::AbtractArray)`. +In general, you know a efficicent and generic way to implement it's `rrule`. +You do so, (likely making use of [`ProjectTo`](@ref)). +But it actually turns out that for some `FancyArray` type it is better to let the AD do it's +thing. + +Then you would write something like: +```julia +function rrule(::typeof(foo), x::AbstractArray) + foo_pullback(ȳ) = ... + return foo(x), foo_pullback end +@opt_out rrule(::typeof(foo), ::FancyArray) +``` + +This will generate a [`rrule`](@ref) that returns `nothing`, +and will also add a similar entry to [`ChainRulesCore.no_rrule`](@ref). + +Similar applies for [`frule`](@ref) and [`ChainRulesCore.no_frule`](@ref) +""" +macro opt_out(expr) + no_rule_target = _no_rule_target_rewrite!(deepcopy(expr)) + + return @strip_linenos quote + $(esc(no_rule_target)) = nothing + $(esc(expr)) = nothing + end +end + +function _no_rule_target_rewrite!(call_target::Symbol) + return if call_target == :rrule + :(ChainRulesCore.no_rrule) + elseif call_target == :frule + :(ChainRulesCore.no_frule) + else + error("Unexpected opt-out target. Exprected frule or rrule, got: $call_target") + end +end +_no_rule_target_rewrite!(qt::QuoteNode) = _no_rule_target_rewrite!(qt.value) +function _no_rule_target_rewrite!(expr::Expr) + length(expr.args)===0 && error("Malformed method expression. $expr") + if expr.head === :call || expr.head === :where + expr.args[1] = _no_rule_target_rewrite!(expr.args[1]) + elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore + expr = _no_rule_target_rewrite!(expr.args[end]) + else + error("Malformed method expression. 
$(expr)") + end + return expr +end + + +############################################################################################ +# Helpers + """ _isvararg(expr) diff --git a/src/rules.jl b/src/rules.jl index 0abc81205..ecda813c9 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -139,3 +139,52 @@ const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance function (::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...) return rrule_kwfunc(kws, rrule, args...) end + +############################################################## +### Opt out functionality + +const NO_RRULE_DOC = """ + no_rrule + +This is an implementation detail for opting out of [`rrule`](@ref). +It follows the signature for `rrule` exactly. +We use it as a way to store a collection of type-tuples in its method-table. +If something has this defined, it means that it must having a must also have a `rrule`, +that returns `nothing`. + +### Machanics +note: when this says methods `==` or `<:` it actually means: +`parameters(m.sig)[2:end]` rather than the method object `m` itself. + +To decide if should opt-out using this mechanism. + - find the most specific method of `rrule` + - find the most specific method of `no_rrule` + - if the method of `no_rrule` `<:` the method of `rrule`, then should opt-out + +To just ignore the fact that rules can be opted-out from, and that some rules thus return +`nothing`, then filter the list of methods of `rrule` to remove those that are `==` to ones +that occur in the method table of `no_rrule`. + +Note also when doing this you must still also handle falling back from rule with config, to +rule without config. + +On the other-hand if your AD can work with `rrule`s that return `nothing`, then it is +simpler to just use that mechanism for opting out; and you don't need to worry about this +at all. +""" + +""" +$NO_RRULE_DOC + +See also [`ChainRulesCore.no_frule`](@ref). 
+""" +function no_rrule end +no_rrule(::Any, ::Vararg{Any}) = nothing + +""" +$(replace(NO_RRULE_DOC, "rrule"=>"frule")) + +See also [`ChainRulesCore.no_rrule`](@ref). +""" +function no_frule end +no_frule(ȧrgs, f, ::Vararg{Any}) = nothing \ No newline at end of file diff --git a/test/rules.jl b/test/rules.jl index f5247797f..d43ca42d2 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -148,4 +148,32 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test_skip ∂xr isa Float64 # to be made true with projection @test_skip ∂xr ≈ real(∂x) end + + + @testset "@opt_out" begin + first_oa(x, y) = x + @scalar_rule(first_oa(x, y), (1, 0)) + @opt_out ChainRulesCore.rrule(::typeof(first_oa), x::T, y::T) where T<:Float32 + @opt_out( + ChainRulesCore.frule(::Any, ::typeof(first_oa), x::T, y::T) where T<:Float32 + ) + + @testset "rrule" begin + @test rrule(first_oa, 3.0, 4.0)[2](1) == (NoTangent(), 1, 0) + @test rrule(first_oa, 3f0, 4f0) === nothing + + @test !isempty(Iterators.filter(methods(ChainRulesCore.no_rrule)) do m + m.sig <:Tuple{Any, typeof(first_oa), T, T} where T<:Float32 + end) + end + + @testset "frule" begin + @test frule((NoTangent(), 1,0), first_oa, 3.0, 4.0) == (3.0, 1) + @test frule((NoTangent(), 1,0), first_oa, 3f0, 4f0) === nothing + + @test !isempty(Iterators.filter(methods(ChainRulesCore.no_frule)) do m + m.sig <:Tuple{Any, Any, typeof(first_oa), T, T} where T<:Float32 + end) + end + end end From c542c2f89fcc1b73b459df8a6586d62548c09a5f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 20 Jul 2021 18:08:24 +0100 Subject: [PATCH 3/6] Apply suggestions from code review Co-authored-by: Miha Zgubic --- docs/src/opting_out_of_rules.md | 26 +++++++++++++------------- src/rule_definition_tools.jl | 10 +++++----- src/rules.jl | 8 ++++---- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/docs/src/opting_out_of_rules.md b/docs/src/opting_out_of_rules.md index c1f7e5894..c8fed4a27 100644 --- a/docs/src/opting_out_of_rules.md +++ 
b/docs/src/opting_out_of_rules.md @@ -2,13 +2,13 @@ It is common to define rules fairly generically. Often matching (or exceeding) how generic the matching original primal method is. -Sometimes this is not the correct behavour. +Sometimes this is not the correct behaviour. Sometimes the AD can do better than this human defined rule. If this is generally the case, then we should not have the rule defined at all. But if it is only the case for a particular set of types, then we want to opt-out just that one. This is done with the [`@opt_out`](@ref) macro. -Consider one might have a rrule for `sum` (the following simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself) +Consider an `rrule` for `sum` (the following is simplified from the one in [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/Base/mapreduce.jl) itself) ```julia function rrule(::typeof(sum), x::AbstractArray{<:Number}; dims=:) y = sum(x; dims=dims) @@ -27,17 +27,17 @@ end That is a fairly reasonable `rrule` for the vast majority of cases. You might have a custom array type for which you could write a faster rule. -For example, the pullback for summing a`SkewSymmetric` matrix can be optimizes to basically be `Diagonal(fill(ȳ, size(x,1)))`. +For example, the pullback for summing a [`SkewSymmetric` (anti-symmetric)](https://en.wikipedia.org/wiki/Skew-symmetric_matrix) matrix can be optimized to basically be `Diagonal(fill(ȳ, size(x,1)))`. To do that, you can indeed write another more specific [`rrule`](@ref). But another case is where the AD system itself would generate a more optimized case. -For example, the a [`NamedDimArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type. -It's sum method is basically just to call `sum` on it's parent. 
+For example, the [`NamedDimsArray`](https://github.com/invenia/NamedDims.jl) is a thin wrapper around some other array type. +Its sum method is basically just to call `sum` on its parent. It is entirely conceivable[^1] that the AD system can do better than our `rrule` here. For example by avoiding the overhead of [`project`ing](@ref ProjectTo). -To opt-out of using the `rrule` and to allow the AD system to do its own thing we use the -[`@opt_out`](@ref) macro, to say to not use it for sum. +To opt-out of using the generic `rrule` and to allow the AD system to do its own thing we use the +[`@opt_out`](@ref) macro, to say to not use it for sum of `NamedDimsArrays`. ```julia @opt_out rrule(::typeof(sum), ::NamedDimsArray) @@ -53,11 +53,11 @@ Similar can be done `@opt_out frule`. It can also be done passing in a [`RuleConfig`](@ref config). -### How to support this (for AD implementers) +## How to support this (for AD implementers) We provide two ways to know that a rule has been opted out of. -## `rrule` / `frule` returns `nothing` +### `rrule` / `frule` returns `nothing` `@opt_out` defines a `frule` or `rrule` matching the signature that returns `nothing`. @@ -70,12 +70,12 @@ else y, pullback = res end ``` -The Julia compiler, will specialize based on inferring the restun type of `rrule`, and so can remove that branch. +The Julia compiler will specialize based on inferring the return type of `rrule`, and so can remove that branch. -## `no_rrule` / `no_frule` has a method +### `no_rrule` / `no_frule` has a method `@opt_out` also defines a method for [`ChainRulesCore.no_frule`](@ref) or [`ChainRulesCore.no_rrule`](@ref). -The use of this method doesn't matter, what matters is it's method-table. +The body of this method doesn't matter, what matters is that it is a method-table. A simple thing you can do with this is not support opting out. To do this, filter all methods from the `rrule`/`frule` method table that also occur in the `no_frule`/`no_rrule` table. 
This will thus avoid ever hitting an `rrule`/`frule` that returns `nothing` and thus makes your library error. @@ -89,4 +89,4 @@ and then `invoke` it. -[^1]: It is also possible, that this is not the case. Benchmark your real uses cases. \ No newline at end of file +[^1]: It is also possible, that this is not the case. Benchmark your real uses cases. diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index cdf424d05..8c185b6da 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -405,13 +405,13 @@ end @opt_out frule([config], _, f, args...) @opt_out rrule([config], f, args...) -This allows you to opt-out of a `frule` or `rrule` by providing a more specific method, -that says to use the AD system, to solver it. +This allows you to opt-out of an `frule` or an `rrule` by providing a more specific method, +that says to use the AD system to differentiate it. For example, consider some function `foo(x::AbtractArray)`. -In general, you know a efficicent and generic way to implement it's `rrule`. +In general, you know an efficient and generic way to implement its `rrule`. You do so, (likely making use of [`ProjectTo`](@ref)). -But it actually turns out that for some `FancyArray` type it is better to let the AD do it's +But it actually turns out that for some `FancyArray` type it is better to let the AD do its thing. Then you would write something like: @@ -424,7 +424,7 @@ end @opt_out rrule(::typeof(foo), ::FancyArray) ``` -This will generate a [`rrule`](@ref) that returns `nothing`, +This will generate an [`rrule`](@ref) that returns `nothing`, and will also add a similar entry to [`ChainRulesCore.no_rrule`](@ref). Similar applies for [`frule`](@ref) and [`ChainRulesCore.no_frule`](@ref) diff --git a/src/rules.jl b/src/rules.jl index ecda813c9..bc90b0454 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -152,9 +152,9 @@ We use it as a way to store a collection of type-tuples in its method-table. 
If something has this defined, it means that it must having a must also have a `rrule`, that returns `nothing`. -### Machanics -note: when this says methods `==` or `<:` it actually means: -`parameters(m.sig)[2:end]` rather than the method object `m` itself. +### Mechanics +note: when the text below says methods `==` or `<:` it actually means: +`parameters(m.sig)[2:end]` (i.e. the signature type tuple) rather than the method object `m` itself. To decide if should opt-out using this mechanism. - find the most specific method of `rrule` @@ -187,4 +187,4 @@ $(replace(NO_RRULE_DOC, "rrule"=>"frule")) See also [`ChainRulesCore.no_rrule`](@ref). """ function no_frule end -no_frule(ȧrgs, f, ::Vararg{Any}) = nothing \ No newline at end of file +no_frule(ȧrgs, f, ::Vararg{Any}) = nothing From a8ffabc934dc469915bc2c0f047e8fb26931d089 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 20 Jul 2021 18:08:41 +0100 Subject: [PATCH 4/6] Update docs/src/opting_out_of_rules.md Co-authored-by: Miha Zgubic --- docs/src/opting_out_of_rules.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/opting_out_of_rules.md b/docs/src/opting_out_of_rules.md index c8fed4a27..9f3bc1732 100644 --- a/docs/src/opting_out_of_rules.md +++ b/docs/src/opting_out_of_rules.md @@ -78,7 +78,7 @@ The Julia compiler will specialize based on inferring the return type of `rrule` The body of this method doesn't matter, what matters is that it is a method-table. A simple thing you can do with this is not support opting out. To do this, filter all methods from the `rrule`/`frule` method table that also occur in the `no_frule`/`no_rrule` table. -This will thus avoid ever hitting an `rrule`/`frule` that returns `nothing` and thus makes your library error. +This will thus avoid ever hitting an `rrule`/`frule` that returns `nothing` (and thus prevents your library from erroring). This is easily done, though it does mean ignoring the user's stated desire to opt out of the rule. 
More complex you can use this to generate code that triggers your AD. From a66f5e7a702430be600c53aaf9034cf8bbef63a8 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 20 Jul 2021 19:44:37 +0100 Subject: [PATCH 5/6] improve docs --- docs/src/opting_out_of_rules.md | 8 +++++++- src/rule_definition_tools.jl | 2 ++ src/rules.jl | 24 ++++++++++++++++-------- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/docs/src/opting_out_of_rules.md b/docs/src/opting_out_of_rules.md index 9f3bc1732..87a2403f0 100644 --- a/docs/src/opting_out_of_rules.md +++ b/docs/src/opting_out_of_rules.md @@ -1,4 +1,4 @@ -# Opting out of rules +# [Opting out of rules](@id opt_out) It is common to define rules fairly generically. Often matching (or exceeding) how generic the matching original primal method is. @@ -53,6 +53,12 @@ Similar can be done `@opt_out frule`. It can also be done passing in a [`RuleConfig`](@ref config). +!!! warning "If the general rule uses a config, the opt-out must also" + Following the same principles as for [rules with config](@ref config), a rule with a `RuleConfig` argument will take precedence over one without, including if that one is an opt-out rule. + But if the general rule does not use a config, then the opt-out rule *can* use a config. + This allows you, for example, to use opt-out to avoid a particular AD system by using an opt-out rule that takes that particular AD's config. + + ## How to support this (for AD implementers) We provide two ways to know that a rule has been opted out of. diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 8c185b6da..122cd2441 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -428,6 +428,8 @@ This will generate an [`rrule`](@ref) that returns `nothing`, and will also add a similar entry to [`ChainRulesCore.no_rrule`](@ref). 
Similar applies for [`frule`](@ref) and [`ChainRulesCore.no_frule`](@ref) + +For more information see the [documentation on opting out of rules](@ref opt_out). """ macro opt_out(expr) no_rule_target = _no_rule_target_rewrite!(deepcopy(expr)) diff --git a/src/rules.jl b/src/rules.jl index bc90b0454..1d9113c24 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -146,20 +146,26 @@ end const NO_RRULE_DOC = """ no_rrule -This is an implementation detail for opting out of [`rrule`](@ref). +This is a piece of infrastructure supporting opting out of [`rrule`](@ref). It follows the signature for `rrule` exactly. -We use it as a way to store a collection of type-tuples in its method-table. -If something has this defined, it means that it must having a must also have a `rrule`, -that returns `nothing`. +A collection of type-tuples is stored in its method-table. +If something has this defined, it means that it must also have an `rrule` +defined that returns `nothing`. + +!!! warning "do not overload no_rrule directly" + It is fine and intended to query the method table of `no_rrule`. + It is not safe to add to that directly, as corresponding changes also need to be made to + `rrule`. + The [`@opt_out`](@ref) macro does both these things, and so should almost always be used + rather than defining a method of `no_rrule` directly. ### Mechanics -note: when the text below says methods `==` or `<:` it actually means: +note: when the text below says methods `==` it actually means: `parameters(m.sig)[2:end]` (i.e. the signature type tuple) rather than the method object `m` itself. To decide if should opt-out using this mechanism. 
- - find the most specific method of `rrule` - - find the most specific method of `no_rrule` - - if the method of `no_rrule` `<:` the method of `rrule`, then should opt-out + - find the most specific method of `rrule` and `no_rrule`, e.g. with `Base.which` + - if the method of `no_rrule` `==` the method of `rrule`, then should opt-out To just ignore the fact that rules can be opted-out from, and that some rules thus return `nothing`, then filter the list of methods of `rrule` to remove those that are `==` to ones @@ -171,6 +177,8 @@ rule without config. On the other-hand if your AD can work with `rrule`s that return `nothing`, then it is simpler to just use that mechanism for opting out; and you don't need to worry about this at all. + +For more information see the [documentation on opting out of rules](@ref opt_out). """ """ From db35df7527e09f69abd9c1414cc77ef7370eb8a3 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 20 Jul 2021 19:46:42 +0100 Subject: [PATCH 6/6] comment on _no_rule_target_rewrite! --- src/rule_definition_tools.jl | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 122cd2441..13d7d9406 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -440,16 +440,7 @@ macro opt_out(expr) end end -function _no_rule_target_rewrite!(call_target::Symbol) - return if call_target == :rrule - :(ChainRulesCore.no_rrule) - elseif call_target == :frule - :(ChainRulesCore.no_frule) - else - error("Unexpected opt-out target. Exprected frule or rrule, got: $call_target") - end -end -_no_rule_target_rewrite!(qt::QuoteNode) = _no_rule_target_rewrite!(qt.value) +"Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`." function _no_rule_target_rewrite!(expr::Expr) length(expr.args)===0 && error("Malformed method expression. 
$expr") if expr.head === :call || expr.head === :where @@ -461,6 +452,17 @@ function _no_rule_target_rewrite!(expr::Expr) end return expr end +_no_rule_target_rewrite!(qt::QuoteNode) = _no_rule_target_rewrite!(qt.value) +function _no_rule_target_rewrite!(call_target::Symbol) + return if call_target == :rrule + :(ChainRulesCore.no_rrule) + elseif call_target == :frule + :(ChainRulesCore.no_frule) + else + error("Unexpected opt-out target. Exprected frule or rrule, got: $call_target") + end +end + ############################################################################################