From a2c2005b241fecd70b8aa92f02ff5c870394ca76 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 26 Nov 2020 19:58:15 +0000 Subject: [PATCH 01/11] add rrule --- src/rule_definition_tools.jl | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 7026c0065..fc4f421f9 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -287,9 +287,7 @@ julia> frule((Zero(), 1), length, [2.0, 3.0]) """ macro non_differentiable(sig_expr) Meta.isexpr(sig_expr, :call) || error("Invalid use of `@non_differentiable`") - for arg in sig_expr.args - _isvararg(arg) && error("@non_differentiable does not support Varargs like: $arg") - end + has_vararg = _isvararg(sig_expr.args[end]) primal_name, orig_args = Iterators.peel(sig_expr.args) @@ -298,7 +296,13 @@ macro non_differentiable(sig_expr) unconstrained_args = _unconstrain.(constrained_args) - primal_invoke = :($(primal_name)($(unconstrained_args...); kwargs...)) + primal_invoke = if !has_vararg + :($(primal_name)($(unconstrained_args...); kwargs...)) + else + normal_args = unconstrained_args[1:end-1] + var_arg = unconstrained_args[end] + :($(primal_name)($(normal_args...), $(var_arg)...; kwargs...)) + end quote $(_nondiff_frule_expr(primal_sig_parts, primal_invoke)) @@ -316,13 +320,22 @@ function _nondiff_frule_expr(primal_sig_parts, primal_invoke) end function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) - num_primal_inputs = length(primal_sig_parts) - 1 + has_vararg = _isvararg(primal_sig_parts[end]) + if !has_vararg + num_primal_inputs = length(primal_sig_parts) - 1 # - primal + tup_expr = Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...) + else + num_primal_inputs = length(primal_sig_parts) - 2 # - primal and vararg + length_expr = :(num_primal_inputs + length($(_unconstrain(primal_sig_parts[end])))) + tup_expr = Expr(:call, :ntuple, Expr(:(->), :_, DoesNotExist()), length_expr) + end primal_name = first(primal_invoke.args) pullback_expr = Expr( :function, Expr(:call, propagator_name(primal_name, :pullback), :_), - Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...) + Expr(:tuple, NO_FIELDS, Expr(:(...), tup_expr)) ) + @show pullback_expr return esc(:( function ChainRulesCore.rrule($(primal_sig_parts...); kwargs...) return ($primal_invoke, $pullback_expr) From e21ae1ae4838fe9c345f4b281f66f519e509bc0d Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 27 Nov 2020 14:36:27 +0000 Subject: [PATCH 02/11] fix vararg case --- src/rule_definition_tools.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index fc4f421f9..faedae0b3 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -323,10 +323,10 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) has_vararg = _isvararg(primal_sig_parts[end]) if !has_vararg num_primal_inputs = length(primal_sig_parts) - 1 # - primal - tup_expr = Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...) + tup_expr = Expr(:tuple, ntuple(_->DoesNotExist(), num_primal_inputs)...) else num_primal_inputs = length(primal_sig_parts) - 2 # - primal and vararg - length_expr = :(num_primal_inputs + length($(_unconstrain(primal_sig_parts[end])))) + length_expr = :($(num_primal_inputs) + length($(_unconstrain(primal_sig_parts[end])))) tup_expr = Expr(:call, :ntuple, Expr(:(->), :_, DoesNotExist()), length_expr) end primal_name = first(primal_invoke.args) From edfd3e5cc290a6714dff21a5b76ea78a93ca0095 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 27 Nov 2020 15:23:17 +0000 Subject: [PATCH 03/11] fix xs... --- src/rule_definition_tools.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index faedae0b3..39cb9bf36 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -400,7 +400,8 @@ end "turn both `a` and `::constraint` into `a::constraint` etc" function _constrain_and_name(arg::Expr, _) Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. - Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) #add name + Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) # add name + Meta.isexpr(arg, :(...), 1) && return Expr(:(::), arg.args[1], :Vararg) # turn xs... into xs::Vararg error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type From 0fdc605fc23352818c1ef6656a2c4a145eccef78 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 27 Nov 2020 15:24:31 +0000 Subject: [PATCH 04/11] remove tests for not supporting varargs --- test/rule_definition_tools.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 6ca7c769f..03a45e05a 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -126,10 +126,6 @@ end end @testset "Not supported (Yet)" begin - # Varargs are not supported - @test_macro_throws ErrorException @non_differentiable vararg1(xs...) - @test_macro_throws ErrorException @non_differentiable vararg1(xs::Vararg) - # Where clauses are not supported. @test_macro_throws( ErrorException, From f68fdd76ad9c8b58eef5547bd72efd4983a816f3 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 27 Nov 2020 15:52:37 +0000 Subject: [PATCH 05/11] change the docstring --- src/rule_definition_tools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 39cb9bf36..1b31e03f3 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -281,7 +281,7 @@ julia> frule((Zero(), 1), length, [2.0, 3.0]) !!! warning This helper macro covers only the simple common cases. - It does not support Varargs, or `where`-clauses. + It does not support `where`-clauses. For these you can declare the `rrule` and `frule` directly """ From a10696a5a993bed3a63e1e892e570f0b49b87c78 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 27 Nov 2020 16:19:32 +0000 Subject: [PATCH 06/11] add tests --- src/rule_definition_tools.jl | 1 - test/rule_definition_tools.jl | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 1b31e03f3..8d591fc4a 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -335,7 +335,6 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) Expr(:call, propagator_name(primal_name, :pullback), :_), Expr(:tuple, NO_FIELDS, Expr(:(...), tup_expr)) ) - @show pullback_expr return esc(:( function ChainRulesCore.rrule($(primal_sig_parts...); kwargs...) return ($primal_invoke, $pullback_expr) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 03a45e05a..0beeba630 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -125,6 +125,39 @@ end @test rrule(NonDiffCounterExample, 2.0) === nothing end + @testset "Varargs" begin + fvarargs(a, xs...) = sum((a, xs...)) + @testset "xs::Vararg{Int}" begin + @non_differentiable fvarargs(a, xs::Vararg{Int}) + + y, pb = rrule(fvarargs, 1) + @test y == fvarargs(1) + @test pb(1) == (Zero(), DoesNotExist()) + + y, pb = rrule(fvarargs, 1, 2, 3) + @test y == fvarargs(1, 2, 3) + @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist(), DoesNotExist()) + + @test frule((1, 1), fvarargs, 1, 2) == (fvarargs(1, 2), DoesNotExist()) + + @test frule((1, 1), fvarargs, 1, 2.0) == nothing + @test rrule(fvarargs, 1, 2.0) == nothing + end + @testset "xs..." begin + @non_differentiable fvarargs(a, xs...) + + y, pb = rrule(fvarargs, 1.) + @test y == fvarargs(1.) + @test pb(1) == (Zero(), DoesNotExist()) + + y, pb = rrule(fvarargs, 1, 2, 3) + @test y == fvarargs(1, 2, 3.) + @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist(), DoesNotExist()) + + @test frule((1, 1.), fvarargs, 1, 2.) == (fvarargs(1, 2.), DoesNotExist()) + end + end + @testset "Not supported (Yet)" begin # Where clauses are not supported. @test_macro_throws( From dc3d178438fda19c5d582ad97592b1510a1d956a Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 27 Nov 2020 16:26:31 +0000 Subject: [PATCH 07/11] bump patch --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bbdecb05b..be7d74ad5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.19" +version = "0.9.20" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From db4ce02ece05563b2d626171d0577a2ec54d321a Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 30 Nov 2020 14:38:12 +0000 Subject: [PATCH 08/11] fix example from code-review --- src/rule_definition_tools.jl | 10 ++++++---- test/rule_definition_tools.jl | 14 ++++++++++++++ test/runtests.jl | 1 - 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 8d591fc4a..1b6e991f4 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -300,8 +300,9 @@ macro non_differentiable(sig_expr) :($(primal_name)($(unconstrained_args...); kwargs...)) else normal_args = unconstrained_args[1:end-1] - var_arg = unconstrained_args[end] - :($(primal_name)($(normal_args...), $(var_arg)...; kwargs...)) + var_arg = unconstrained_args[end] # either `xs...`` or `xs`, coming from Vararg + var_arg_call = Meta.isexpr(var_arg, :(...), 1) ? var_arg : Expr(:(...), var_arg) + :($(primal_name)($(normal_args...), $(var_arg_call); kwargs...)) end quote @@ -393,6 +394,7 @@ end _unconstrain(arg::Symbol) = arg function _unconstrain(arg::Expr) Meta.isexpr(arg, :(::), 2) && return arg.args[1] # drop constraint. + Meta.isexpr(arg, :(...), 1) && return Expr(:(...), _unconstrain(arg.args[1])) error("malformed arguments: $arg") end @@ -400,7 +402,7 @@ end function _constrain_and_name(arg::Expr, _) Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) # add name - Meta.isexpr(arg, :(...), 1) && return Expr(:(::), arg.args[1], :Vararg) # turn xs... into xs::Vararg + Meta.isexpr(arg, :(...), 1) && return Expr(:(...), _constrain_and_name(arg.args[1], :Any)) error("malformed arguments: $arg") end -_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type +_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type \ No newline at end of file diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 0beeba630..57a5f8b48 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -143,6 +143,20 @@ end @test frule((1, 1), fvarargs, 1, 2.0) == nothing @test rrule(fvarargs, 1, 2.0) == nothing end + + @testset "::Float64..." begin + @non_differentiable fvarargs(a, ::Float64...) + + y, pb = rrule(fvarargs, 1.0, 1.0) + @test y == fvarargs(1.0, 1.0) + @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist()) + + @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), DoesNotExist()) + + @test frule((1, 1, 1), fvarargs, 1, 1, 2.0) == nothing + @test rrule(fvarargs, 1, 1, 2.0) == nothing + end + @testset "xs..." begin @non_differentiable fvarargs(a, xs...) diff --git a/test/runtests.jl b/test/runtests.jl index eab854b10..ed22f83c1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,7 +20,6 @@ using Test include("rules.jl") include("rule_definition_tools.jl") - @testset "demos" begin include("demos/forwarddiffzero.jl") include("demos/reversediffzero.jl") From 5f36b03b50bbe6b721d0214a9334adf9fcdaa330 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 30 Nov 2020 16:00:06 +0000 Subject: [PATCH 09/11] more elegant solution and more complete tests --- src/rule_definition_tools.jl | 7 +++--- test/rule_definition_tools.jl | 42 ++++++++++++++--------------------- 2 files changed, 20 insertions(+), 29 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 1b6e991f4..d783638f3 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -300,9 +300,8 @@ macro non_differentiable(sig_expr) :($(primal_name)($(unconstrained_args...); kwargs...)) else normal_args = unconstrained_args[1:end-1] - var_arg = unconstrained_args[end] # either `xs...`` or `xs`, coming from Vararg - var_arg_call = Meta.isexpr(var_arg, :(...), 1) ? var_arg : Expr(:(...), var_arg) - :($(primal_name)($(normal_args...), $(var_arg_call); kwargs...)) + var_arg = unconstrained_args[end] + :($(primal_name)($(normal_args...), $(var_arg)...; kwargs...)) end quote @@ -394,7 +393,7 @@ end _unconstrain(arg::Symbol) = arg function _unconstrain(arg::Expr) Meta.isexpr(arg, :(::), 2) && return arg.args[1] # drop constraint. - Meta.isexpr(arg, :(...), 1) && return Expr(:(...), _unconstrain(arg.args[1])) + Meta.isexpr(arg, :(...), 1) && return _unconstrain(arg.args[1]) error("malformed arguments: $arg") end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 57a5f8b48..cef6f9887 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -127,48 +127,40 @@ end @testset "Varargs" begin fvarargs(a, xs...) = sum((a, xs...)) - @testset "xs::Vararg{Int}" begin - @non_differentiable fvarargs(a, xs::Vararg{Int}) + @testset "xs::Float64..." begin + @non_differentiable fvarargs(a, xs::Float64...) y, pb = rrule(fvarargs, 1) @test y == fvarargs(1) @test pb(1) == (Zero(), DoesNotExist()) - y, pb = rrule(fvarargs, 1, 2, 3) - @test y == fvarargs(1, 2, 3) + y, pb = rrule(fvarargs, 1, 2.0, 3.0) + @test y == fvarargs(1, 2.0, 3.0) @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist(), DoesNotExist()) - @test frule((1, 1), fvarargs, 1, 2) == (fvarargs(1, 2), DoesNotExist()) + @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), DoesNotExist()) - @test frule((1, 1), fvarargs, 1, 2.0) == nothing - @test rrule(fvarargs, 1, 2.0) == nothing + @test frule((1, 1), fvarargs, 1, 2) == nothing + @test rrule(fvarargs, 1, 2) == nothing end @testset "::Float64..." begin @non_differentiable fvarargs(a, ::Float64...) - - y, pb = rrule(fvarargs, 1.0, 1.0) - @test y == fvarargs(1.0, 1.0) - @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist()) - @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), DoesNotExist()) + y, pb = rrule(fvarargs, 1, 2.0, 3.0) + @test y == fvarargs(1, 2.0, 3.0) + @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist(), DoesNotExist()) - @test frule((1, 1, 1), fvarargs, 1, 1, 2.0) == nothing - @test rrule(fvarargs, 1, 1, 2.0) == nothing + @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), DoesNotExist()) end - - @testset "xs..." begin - @non_differentiable fvarargs(a, xs...) - y, pb = rrule(fvarargs, 1.) - @test y == fvarargs(1.) - @test pb(1) == (Zero(), DoesNotExist()) - - y, pb = rrule(fvarargs, 1, 2, 3) - @test y == fvarargs(1, 2, 3.) - @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist(), DoesNotExist()) + @testset "::Vararg" begin + @non_differentiable fvarargs(a, ::Vararg) + @test frule((1, 1), fvarargs, 1, 2) == (fvarargs(1, 2), DoesNotExist()) - @test frule((1, 1.), fvarargs, 1, 2.) == (fvarargs(1, 2.), DoesNotExist()) + y, pb = rrule(fvarargs, 1, 1) + @test y == fvarargs(1, 1) + @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist()) end end From 6620f80fcb2473bc521496541a519fe4306d1be1 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Dec 2020 14:31:22 +0000 Subject: [PATCH 10/11] mini expr change --- src/rule_definition_tools.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index d783638f3..7ff595ff5 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -319,16 +319,20 @@ function _nondiff_frule_expr(primal_sig_parts, primal_invoke) )) end -function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) +function tuple_expression(primal_sig_parts) has_vararg = _isvararg(primal_sig_parts[end]) - if !has_vararg + return if !has_vararg num_primal_inputs = length(primal_sig_parts) - 1 # - primal - tup_expr = Expr(:tuple, ntuple(_->DoesNotExist(), num_primal_inputs)...) + Expr(:tuple, ntuple(_->DoesNotExist(), num_primal_inputs)...) else num_primal_inputs = length(primal_sig_parts) - 2 # - primal and vararg length_expr = :($(num_primal_inputs) + length($(_unconstrain(primal_sig_parts[end])))) - tup_expr = Expr(:call, :ntuple, Expr(:(->), :_, DoesNotExist()), length_expr) + Expr(:call, :ntuple, Expr(:(->), :_, DoesNotExist()), length_expr) end +end + +function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) + tup_expr = tuple_expression(primal_sig_parts) primal_name = first(primal_invoke.args) pullback_expr = Expr( :function, From 04aa23653002d36e3b7fa1ec7644e9625d9c3961 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 1 Dec 2020 14:31:29 +0000 Subject: [PATCH 11/11] add two tests --- test/rule_definition_tools.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index cef6f9887..c4a0df8d3 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -154,6 +154,16 @@ end @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), DoesNotExist()) end + @testset "::Vararg{Float64}" begin + @non_differentiable fvarargs(a, ::Vararg{Float64}) + + y, pb = rrule(fvarargs, 1, 2.0, 3.0) + @test y == fvarargs(1, 2.0, 3.0) + @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist(), DoesNotExist()) + + @test frule((1, 1), fvarargs, 1, 2.0) == (fvarargs(1, 2.0), DoesNotExist()) + end + @testset "::Vararg" begin @non_differentiable fvarargs(a, ::Vararg) @test frule((1, 1), fvarargs, 1, 2) == (fvarargs(1, 2), DoesNotExist()) @@ -162,6 +172,15 @@ end @test y == fvarargs(1, 1) @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist()) end + + @testset "xs..." begin + @non_differentiable fvarargs(a, xs...) + @test frule((1, 1), fvarargs, 1, 2) == (fvarargs(1, 2), DoesNotExist()) + + y, pb = rrule(fvarargs, 1, 1) + @test y == fvarargs(1, 1) + @test pb(1) == (Zero(), DoesNotExist(), DoesNotExist()) + end end @testset "Not supported (Yet)" begin