diff --git a/Project.toml b/Project.toml
index ca2aa59..c82d5ff 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,6 +8,7 @@ julia = "1"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test"]
+test = ["Test", "Zygote"]
diff --git a/src/functor.jl b/src/functor.jl
index b08acf7..616e66e 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -75,10 +75,38 @@ Equivalent to `functor(x)[1]`.
 """
 children(x) = functor(x)[1]
 
+# Combine a child with its companion value (e.g. its gradient): tuples are
+# walked element-wise, a `nothing` companion leaves the child untouched, and
+# any other pair is handed directly to `f`.
+function functor_tuple(f, x::Tuple, dx::Tuple)
+  map(x, dx) do x, x̄
+    _default_walk(f, x, x̄)
+  end
+end
+functor_tuple(f, x, dx) = f(x, dx)
+functor_tuple(f, x, ::Nothing) = x
+
+# Walk `x` in lockstep with a companion `dx` whose structure mirrors the
+# children of `x`. For example, with `@functor Chain` the children are
+# `(layers = (Dense, Dense),)`; `dx` is presumably shaped `(layers = (...),)`.
+function _default_walk(f, x, dx)
+  func, re = functor(x)
+  re(map(f, func, dx))
+end
+
+# A `nothing` companion (e.g. a sub-structure that has no gradient) has no
+# children to pair with, so every child of `x` is visited with `nothing`
+# instead of failing inside `map`.
+function _default_walk(f, x, ::Nothing)
+  func, re = functor(x)
+  re(map(child -> f(child, nothing), func))
+end
+
 function _default_walk(f, x)
   func, re = functor(x)
   re(map(f, func))
 end
+_default_walk(f, ::Nothing, ::Nothing) = nothing
 
 """
     fmap(f, x; exclude = isleaf, walk = Functors._default_walk)
@@ -205,3 +233,15 @@ function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
   end
   return output
 end
+
+# Allow gradients and other constructs that match the structure of the functor
+# to take part in `map`-style computations that return a modified version of
+# the struct. This way `fmap` can be used to update the params with their gradients.
+function fmap(f, x, dx...; cache = IdDict())
+  haskey(cache, x) && return cache[x]
+  cache[x] = if isleaf(x)
+    f(x, dx...)
+  else
+    _default_walk((xs...) -> fmap(f, xs...; cache = cache), x, dx...)
+  end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 4ccc029..a145c63 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,8 +1,9 @@
 using Functors, Test
+using Zygote
 
 @testset "Functors.jl" begin
 
-include("basics.jl")
-include("base.jl")
-
+  include("basics.jl")
+  include("base.jl")
+  include("update.jl")
 end
diff --git a/test/update.jl b/test/update.jl
new file mode 100644
index 0000000..0ed6bca
--- /dev/null
+++ b/test/update.jl
@@ -0,0 +1,25 @@
+# Type and method definitions must live at top level: `@testset` introduces a
+# local scope, where `struct` is a syntax error.
+struct M{F,T,S}
+  σ::F
+  W::T
+  b::S
+end
+
+@functor M
+
+(m::M)(x) = m.σ.(m.W * x .+ m.b)
+
+@testset "Generalized fmap over equivalent functors" begin
+  m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3))
+  x = ones(Float32, 4, 2)
+  m̄, _ = gradient((m,x) -> sum(m(x)), m, x)
+  m̂ = Functors.fmap(m, m̄) do x, y
+    isnothing(x) && return y
+    isnothing(y) && return x
+    x .- 0.1f0 .* y
+  end
+
+  @test m̂.W ≈ fill(0.8f0, size(m.W))
+  @test m̂.b ≈ fill(-0.2f0, size(m.b))
+end