diff --git a/Project.toml b/Project.toml index bbdecb05b..be7d74ad5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.19" +version = "0.9.20" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 7026c0065..7ff595ff5 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -281,15 +281,13 @@ julia> frule((Zero(), 1), length, [2.0, 3.0]) !!! warning This helper macro covers only the simple common cases. - It does not support Varargs, or `where`-clauses. + It does not support `where`-clauses. For these you can declare the `rrule` and `frule` directly """ macro non_differentiable(sig_expr) Meta.isexpr(sig_expr, :call) || error("Invalid use of `@non_differentiable`") - for arg in sig_expr.args - _isvararg(arg) && error("@non_differentiable does not support Varargs like: $arg") - end + has_vararg = _isvararg(sig_expr.args[end]) primal_name, orig_args = Iterators.peel(sig_expr.args) @@ -298,7 +296,13 @@ macro non_differentiable(sig_expr) unconstrained_args = _unconstrain.(constrained_args) - primal_invoke = :($(primal_name)($(unconstrained_args...); kwargs...)) + primal_invoke = if !has_vararg + :($(primal_name)($(unconstrained_args...); kwargs...)) + else + normal_args = unconstrained_args[1:end-1] + var_arg = unconstrained_args[end] + :($(primal_name)($(normal_args...), $(var_arg)...; kwargs...)) + end quote $(_nondiff_frule_expr(primal_sig_parts, primal_invoke)) @@ -315,13 +319,25 @@ function _nondiff_frule_expr(primal_sig_parts, primal_invoke) )) end +function tuple_expression(primal_sig_parts) + has_vararg = _isvararg(primal_sig_parts[end]) + return if !has_vararg + num_primal_inputs = length(primal_sig_parts) - 1 # - primal + Expr(:tuple, ntuple(_->DoesNotExist(), num_primal_inputs)...) + else + num_primal_inputs = length(primal_sig_parts) - 2 # - primal and vararg + length_expr = :($(num_primal_inputs) + length($(_unconstrain(primal_sig_parts[end])))) + Expr(:call, :ntuple, Expr(:(->), :_, DoesNotExist()), length_expr) + end +end + function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) - num_primal_inputs = length(primal_sig_parts) - 1 + tup_expr = tuple_expression(primal_sig_parts) primal_name = first(primal_invoke.args) pullback_expr = Expr( :function, Expr(:call, propagator_name(primal_name, :pullback), :_), - Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...) + Expr(:tuple, NO_FIELDS, Expr(:(...), tup_expr)) ) return esc(:( function ChainRulesCore.rrule($(primal_sig_parts...); kwargs...) @@ -381,13 +397,15 @@ end _unconstrain(arg::Symbol) = arg function _unconstrain(arg::Expr) Meta.isexpr(arg, :(::), 2) && return arg.args[1] # drop constraint. + Meta.isexpr(arg, :(...), 1) && return _unconstrain(arg.args[1]) error("malformed arguments: $arg") end "turn both `a` and `::constraint` into `a::constraint` etc" function _constrain_and_name(arg::Expr, _) Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. - Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) #add name + Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) # add name + Meta.isexpr(arg, :(...), 1) && return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) error("malformed arguments: $arg") end -_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type +_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type \ No newline at end of file diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 6ca7c769f..c4a0df8d3 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -125,11 +125,65 @@ end @test rrule(NonDiffCounterExample, 2.0) === nothing end - @testset "Not supported (Yet)" begin - # Varargs are not supported - @test_macro_throws ErrorException @non_differentiable vararg1(xs...) - @test_macro_throws ErrorException @non_differentiable vararg1(xs::Vararg) + @testset "Varargs" begin + fvarargs(a, xs...) = sum((a, xs...)) + @testset "xs::Float64..." begin + @non_differentiable fvarargs(a, xs::Float64...) + + y, pb = rrule(fvarargs, 1) + @test y == fvarargs(1) + @test pb(1) == (Zero(), DoesNotExist()) + + y, pb = rrule(fvarargs, 1, 2.0, 3.0) + @test y == fvarargs(1, 2.0, 3.0) + @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist(), DoesNotExist()) + + @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), DoesNotExist()) + + @test frule((1, 1), fvarargs, 1, 2) == nothing + @test rrule(fvarargs, 1, 2) == nothing + end + + @testset "::Float64..." begin + @non_differentiable fvarargs(a, ::Float64...) + + y, pb = rrule(fvarargs, 1, 2.0, 3.0) + @test y == fvarargs(1, 2.0, 3.0) + @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist(), DoesNotExist()) + + @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), DoesNotExist()) + end + + @testset "::Vararg{Float64}" begin + @non_differentiable fvarargs(a, ::Vararg{Float64}) + y, pb = rrule(fvarargs, 1, 2.0, 3.0) + @test y == fvarargs(1, 2.0, 3.0) + @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist(), DoesNotExist()) + + @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), DoesNotExist()) + end + + @testset "::Vararg" begin + @non_differentiable fvarargs(a, ::Vararg) + @test frule((1, 1), fvarargs, 1, 2) == (fvarargs(1, 2), DoesNotExist()) + + y, pb = rrule(fvarargs, 1, 1) + @test y == fvarargs(1, 1) + @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist()) + end + + @testset "xs..." begin + @non_differentiable fvarargs(a, xs...) + @test frule((1, 1), fvarargs, 1, 2) == (fvarargs(1, 2), DoesNotExist()) + + y, pb = rrule(fvarargs, 1, 1) + @test y == fvarargs(1, 1) + @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist()) + end + end + + @testset "Not supported (Yet)" begin # Where clauses are not supported. @test_macro_throws( ErrorException, diff --git a/test/runtests.jl b/test/runtests.jl index eab854b10..ed22f83c1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,7 +20,6 @@ using Test include("rules.jl") include("rule_definition_tools.jl") - @testset "demos" begin include("demos/forwarddiffzero.jl") include("demos/reversediffzero.jl")