diff --git a/Project.toml b/Project.toml index 0d786877d..e586ea3ee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.6" +version = "0.9.7" [deps] MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index ca2b0a3ce..b16972e51 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -4,7 +4,7 @@ using MuladdMacro: @muladd export on_new_rule, refresh_rules # generation tools export frule, rrule # core function -export @scalar_rule, @thunk # definition helper macros +export @non_differentiable, @scalar_rule, @thunk # definition helper macros export canonicalize, extern, unthunk # differential operations # differentials export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 2675c7598..b686ffc0d 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -117,18 +117,10 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) @assert Meta.isexpr(call, :call) # Annotate all arguments in the signature as scalars - inputs = map(call.args[2:end]) do arg - esc(Meta.isexpr(arg, :(::)) ? 
arg : Expr(:(::), arg, :Number)) - end - + inputs = esc.(_constrain_and_name.(call.args[2:end], :Number)) # Remove annotations and escape names for the call - for (i, arg) in enumerate(call.args) - if Meta.isexpr(arg, :(::)) - call.args[i] = esc(first(arg.args)) - else - call.args[i] = esc(arg) - end - end + call.args[2:end] .= _unconstrain.(call.args[2:end]) + call.args = esc.(call.args) # For consistency in code that follows we make all partials tuple expressions partials = map(partials) do partial @@ -143,6 +135,7 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) return call, setup_stmts, inputs, partials end + function scalar_frule_expr(f, call, setup_stmts, inputs, partials) n_outputs = length(partials) n_inputs = length(inputs) @@ -178,7 +171,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials) # Δs is the input to the propagator rule # because this is a pull-back there is one per output of function - Δs = [Symbol(string(:Δ, i)) for i in 1:n_outputs] + Δs = [Symbol(:Δ, i) for i in 1:n_outputs] # 1 partial derivative per input pullback_returns = map(1:n_inputs) do input_i @@ -189,7 +182,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials) # Multi-output functions have pullbacks with a tuple input that will be destructured pullback_input = n_outputs == 1 ? first(Δs) : Expr(:tuple, Δs...) pullback = quote - function $(propagator_name(f, :pullback))($pullback_input) + function $(esc(propagator_name(f, :pullback)))($pullback_input) return (NO_FIELDS, $(pullback_returns...)) end end @@ -215,16 +208,14 @@ function propagation_expr(Δs, ∂s, _conj = false) ∂s = map(esc, ∂s) n∂s = length(∂s) - # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression - # literals. + # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression literals. 
∂_mul_Δs = if _conj ntuple(i->:(conj($(∂s[i])) * $(Δs[i])), n∂s) else ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s) end - # Avoiding the extra `+` operation, it is potentially expensive for vector - # mode AD. + # Avoiding the extra `+` operation, it is potentially expensive for vector mode AD. sumed_∂_mul_Δs = if n∂s > 1 # we use `@.` to broadcast `*` and `+` :(@. +($(∂_mul_Δs...))) @@ -258,3 +249,143 @@ This is able to deal with fairly complex expressions for `f`: propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propname) propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname) propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname) + +""" + @non_differentiable(signature_expression) + +A helper to make it easier to declare that a method is not differentiable. +This is a short-hand for defining an [`frule`](@ref) and [`rrule`](@ref) that +return [`DoesNotExist()`](@ref) for all partials (except for the function `s̄elf`-partial +itself which is `NO_FIELDS`). + +Keyword arguments should not be included. + +```jldoctest +julia> @non_differentiable Base.:(==)(a, b) + +julia> _, pullback = rrule(==, 2.0, 3.0); + +julia> pullback(1.0) +(Zero(), DoesNotExist(), DoesNotExist()) +``` + +You can place type-constraints in the signature: +```jldoctest +julia> @non_differentiable Base.length(xs::Union{Number, Array}) + +julia> frule((Zero(), 1), length, [2.0, 3.0]) +(2, DoesNotExist()) +``` + +!!! warning + This helper macro covers only the simple common cases. + It does not support Varargs, or `where`-clauses. 
+ For these you can declare the `rrule` and `frule` directly + +""" +macro non_differentiable(sig_expr) + Meta.isexpr(sig_expr, :call) || error("Invalid use of `@non_differentiable`") + for arg in sig_expr.args + _isvararg(arg) && error("@non_differentiable does not support Varargs like: $arg") + end + + primal_name, orig_args = Iterators.peel(sig_expr.args) + + constrained_args = _constrain_and_name.(orig_args, :Any) + primal_sig_parts = [:(::typeof($primal_name)), constrained_args...] + + unconstrained_args = _unconstrain.(constrained_args) + primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...) + + quote + $(_nondiff_frule_expr(primal_sig_parts, primal_invoke)) + $(_nondiff_rrule_expr(primal_sig_parts, primal_invoke)) + end +end + +function _nondiff_frule_expr(primal_sig_parts, primal_invoke) + return Expr( + :(=), + Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...), + # Julia functions always only have 1 output, so just return a single DoesNotExist() + Expr(:tuple, primal_invoke, DoesNotExist()), + ) +end + +function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) + num_primal_inputs = length(primal_sig_parts) - 1 + primal_name = first(primal_invoke.args) + pullback_expr = Expr( + :function, + Expr(:call, esc(propagator_name(primal_name, :pullback)), esc(:_)), + Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...) 
+ ) + rrule_defn = Expr( + :(=), + Expr(:call, :(ChainRulesCore.rrule), esc.(primal_sig_parts)...), + Expr(:tuple, primal_invoke, pullback_expr), + ) + return rrule_defn +end + + +########### +# Helpers + +""" + _isvararg(expr) + +Return `true` if the expression could represent a vararg. + +```jldoctest +julia> ChainRulesCore._isvararg(:(x...)) +true + +julia> ChainRulesCore._isvararg(:(x::Int...)) +true + +julia> ChainRulesCore._isvararg(:(::Int...)) +true + +julia> ChainRulesCore._isvararg(:(x::Vararg)) +true + +julia> ChainRulesCore._isvararg(:(x::Vararg{Int})) +true + +julia> ChainRulesCore._isvararg(:(::Vararg)) +true + +julia> ChainRulesCore._isvararg(:(::Vararg{Int})) +true + +julia> ChainRulesCore._isvararg(:(x)) +false +``` +""" +_isvararg(expr) = false +function _isvararg(expr::Expr) + Meta.isexpr(expr, :...) && return true + if Meta.isexpr(expr, :(::)) + constraint = last(expr.args) + constraint == :Vararg && return true + Meta.isexpr(constraint, :curly) && first(constraint.args) == :Vararg && return true + end + return false +end + + +"turn both `a` and `a::S` into `a`" +_unconstrain(arg::Symbol) = arg +function _unconstrain(arg::Expr) + Meta.isexpr(arg, :(::), 2) && return arg.args[1] # drop constraint. + error("malformed arguments: $arg") +end + +"turn both `a` and `::constraint` into `a::constraint` etc" +function _constrain_and_name(arg::Expr, _) + Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. + Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) #add name + error("malformed arguments: $arg") +end +_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl new file mode 100644 index 000000000..7a05cbfec --- /dev/null +++ b/test/rule_definition_tools.jl @@ -0,0 +1,88 @@ +""" +Along the same lines as `@test_throws` but to test if a macro throws an exception when it is +expanded. 
+""" +macro test_macro_throws(err_expr, expr) + quote + err = nothing + try + @macroexpand($(esc(expr))) + catch load_err + # all errors thrown at macro expansion time are LoadErrors, we need to unwrap + @assert load_err isa LoadError + err = load_err.error + end + # Reuse `@test_throws` logic + if err!==nothing + @test_throws $(esc(err_expr)) ($(Meta.quot(expr)); throw(err)) + else + @test_throws $(esc(err_expr)) $(Meta.quot(expr)) + end + end +end + + +@testset "rule_definition_tools.jl" begin + @testset "@non_differentiable" begin + @testset "two input one output function" begin + nondiff_2_1(x, y) = fill(7.5, 100)[x + y] + @non_differentiable nondiff_2_1(::Any, ::Any) + @test frule((Zero(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, DoesNotExist()) + res, pullback = rrule(nondiff_2_1, 3, 2) + @test res == 7.5 + @test pullback(4.5) == (NO_FIELDS, DoesNotExist(), DoesNotExist()) + end + + @testset "one input, 2-tuple output function" begin + nondiff_1_2(x) = (5.0, 3.0) + @non_differentiable nondiff_1_2(::Any) + @test frule((Zero(), 1.2), nondiff_1_2, 3.1) == ((5.0, 3.0), DoesNotExist()) + res, pullback = rrule(nondiff_1_2, 3.1) + @test res == (5.0, 3.0) + @test isequal( + pullback(Composite{Tuple{Float64, Float64}}(1.2, 3.2)), + (NO_FIELDS, DoesNotExist()), + ) + end + + @testset "constrained signature" begin + nonembed_identity(x) = x + @non_differentiable nonembed_identity(::Integer) + + @test frule((Zero(), 1.2), nonembed_identity, 2) == (2, DoesNotExist()) + @test frule((Zero(), 1.2), nonembed_identity, 2.0) == nothing + + res, pullback = rrule(nonembed_identity, 2) + @test res == 2 + @test pullback(1.2) == (NO_FIELDS, DoesNotExist()) + + @test rrule(nonembed_identity, 2.0) == nothing + end + + @testset "Pointy UnionAll constraints" begin + pointy_identity(x) = x + @non_differentiable pointy_identity(::Vector{<:AbstractString}) + + @test frule((Zero(), 1.2), pointy_identity, ["2"]) == (["2"], DoesNotExist()) + @test frule((Zero(), 1.2), pointy_identity, 2.0) == 
nothing + + res, pullback = rrule(pointy_identity, ["2"]) + @test res == ["2"] + @test pullback(1.2) == (NO_FIELDS, DoesNotExist()) + + @test rrule(pointy_identity, 2.0) == nothing + end + + @testset "Not supported (Yet)" begin + # Varargs are not supported + @test_macro_throws ErrorException @non_differentiable vararg1(xs...) + @test_macro_throws ErrorException @non_differentiable vararg1(xs::Vararg) + + # Where clauses are not supported. + @test_macro_throws( + ErrorException, + (@non_differentiable where_identity(::Vector{T}) where T<:AbstractString) + ) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 8f995b354..306f846d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,7 @@ using Test include("ruleset_loading.jl") include("rules.jl") + include("rule_definition_tools.jl") @testset "demos" begin