From 1ac81f089d63ab1f62cf7cb7aeec1d0a3da88802 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 24 Apr 2020 12:23:04 +0530 Subject: [PATCH 1/9] handle dmap1(f, x, dx) --- src/functor.jl | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/functor.jl b/src/functor.jl index 3d6377d..95054be 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -9,7 +9,7 @@ functor(::Type{<:AbstractArray{<:Number}}, x) = (), _ -> x function makefunctor(m::Module, T, fs = fieldnames(T)) @eval m begin - Flux.functor(::Type{<:$T}, x) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...) + Functors.functor(::Type{<:$T}, x) = ($([:($f=x.$f) for f in fs]...),), y -> $T(y...) end end @@ -25,6 +25,24 @@ end isleaf(x) = functor(x)[1] === () +# for Chain +function t(f, x::Tuple, dx::Tuple) + map(x, dx) do x, x̄ + fmap1(f, x, x̄) + end +end +t(f, x, dx) = f(x, dx) +t(f, x, ::Nothing) = x + +# @functor Chain +# Chain -> func = (layers = (Dense,Dense),), gs -> (layers...) +function fmap1(f, x, dx) + func, re = functor(x) + map(func, dx) do x, x̄ + t(f, x, x̄) + end |> re +end + function fmap1(f, x) func, re = functor(x) re(map(f, func)) @@ -34,3 +52,8 @@ function fmap(f, x; cache = IdDict()) haskey(cache, x) && return cache[x] cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x) end + +function fmap(f, x, dx; cache = IdDict()) + haskey(cache, x) && return cache[x] + cache[x] = isleaf(x) ? f(x, dx) : fmap1((x...) -> fmap(f, x..., cache = cache), x, dx) +end From f8cb762b4a6cd955cd4f3a16b2c7d7720b4d9430 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 24 Apr 2020 12:25:44 +0530 Subject: [PATCH 2/9] rename util fn --- src/functor.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 95054be..9cf0819 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -26,20 +26,20 @@ end isleaf(x) = functor(x)[1] === () # for Chain -function t(f, x::Tuple, dx::Tuple) +function functor_tuple(f, x::Tuple, dx::Tuple) map(x, dx) do x, x̄ fmap1(f, x, x̄) end end -t(f, x, dx) = f(x, dx) -t(f, x, ::Nothing) = x +functor_tuple(f, x, dx) = f(x, dx) +functor_tuple(f, x, ::Nothing) = x # @functor Chain # Chain -> func = (layers = (Dense,Dense),), gs -> (layers...) function fmap1(f, x, dx) func, re = functor(x) map(func, dx) do x, x̄ - t(f, x, x̄) + functor_tuple(f, x, x̄) end |> re end From 96386152092a3f3c74befae96e4da1e08823ae7e Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Thu, 5 Nov 2020 16:49:30 +0530 Subject: [PATCH 3/9] make dx a vararg --- src/functor.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 659ad23..45829fc 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -55,7 +55,7 @@ function fmap(f, x; cache = IdDict()) cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x) end -function fmap(f, x, dx; cache = IdDict()) +function fmap(f, x, dx...; cache = IdDict()) haskey(cache, x) && return cache[x] - cache[x] = isleaf(x) ? f(x, dx) : fmap1((x...) -> fmap(f, x..., cache = cache), x, dx) + cache[x] = isleaf(x) ? f(x, dx...) : fmap1((x...) -> fmap(f, x..., cache = cache), x, dx...) end From 9b04a3523ceac602af0219be5d893a5d1e13f2cc Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 3 Feb 2021 18:01:12 +0530 Subject: [PATCH 4/9] add comments --- src/functor.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/functor.jl b/src/functor.jl index 45829fc..83b78b1 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -54,7 +54,9 @@ function fmap(f, x; cache = IdDict()) haskey(cache, x) && return cache[x] cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x) end - +# Allow gradients and other constructs that match the structure of the functor +# to allow for `map` style computations and return a modified version of the struct. +# This way we can use `fmap` to update the params with their gradients function fmap(f, x, dx...; cache = IdDict()) haskey(cache, x) && return cache[x] cache[x] = isleaf(x) ? f(x, dx...) : fmap1((x...) -> fmap(f, x..., cache = cache), x, dx...) From 10ec1aee73d4cdccf5d7fe5d0074b08a860b52b8 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 3 Feb 2021 18:02:01 +0530 Subject: [PATCH 5/9] whitespace --- src/functor.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/functor.jl b/src/functor.jl index b838022..308cdc5 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -60,6 +60,7 @@ function fmap(f, x; cache = IdDict()) haskey(cache, x) && return cache[x] cache[x] = isleaf(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache), x) end + # Allow gradients and other constructs that match the structure of the functor # to allow for `map` style computations and return a modified version of the struct. # This way we can use `fmap` to update the params with their gradients From 13348e88d4506730305b0647f46712c2726bc47f Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sun, 20 Jun 2021 18:13:20 +0530 Subject: [PATCH 6/9] invariant functor call --- src/functor.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/functor.jl b/src/functor.jl index f29762f..f91a0a1 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -58,7 +58,8 @@ functor_tuple(f, x, ::Nothing) = x function fmap1(f, x, dx) func, re = functor(x) map(func, dx) do x, x̄ - functor_tuple(f, x, x̄) + # functor_tuple(f, x, x̄) + f(x, x̄) end |> re end From ac8a9007b644bd8a80ae5d4d6b4041dd3101ddeb Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 21 Jul 2021 20:00:45 +0530 Subject: [PATCH 7/9] handle nothing-nothing --- src/functor.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index e056004..91d60dc 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -46,7 +46,7 @@ children(x) = functor(x)[1] function functor_tuple(f, x::Tuple, dx::Tuple) map(x, dx) do x, x̄ - fmap1(f, x, x̄) + _default_walk(f, x, x̄) end end functor_tuple(f, x, dx) = f(x, dx) @@ -66,6 +66,7 @@ function _default_walk(f, x) func, re = functor(x) re(map(f, func)) end +_default_walk(f, ::Nothing, ::Nothing) = nothing """ fmap(f, x; exclude = isleaf, walk = Functors._default_walk) @@ -193,5 +194,5 @@ end # This way we can use `fmap` to update the params with their gradients function fmap(f, x, dx...; cache = IdDict()) haskey(cache, x) && return cache[x] - cache[x] = isleaf(x) ? f(x, dx...) : fmap1((x...) -> fmap(f, x..., cache = cache), x, dx...) + cache[x] = isleaf(x) ? f(x, dx...) : _default_walk((x...) -> fmap(f, x..., cache = cache), x, dx...) end From 3354cfb185e172f976ffb9d318607c158fe1921c Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 21 Jul 2021 20:03:39 +0530 Subject: [PATCH 8/9] add Zygote to test --- Project.toml | 3 ++- test/runtests.jl | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 36ff12a..557b27b 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ julia = "1" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test"] +test = ["Test", "Zygote"] diff --git a/test/runtests.jl b/test/runtests.jl index d824e63..a145c63 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,8 @@ using Functors, Test +using Zygote @testset "Functors.jl" begin -include("basics.jl") - + include("basics.jl") + include("update.jl") end From 101f643b9101be15b54e64a6026b768389a2c17b Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 21 Jul 2021 20:04:14 +0530 Subject: [PATCH 9/9] add test for generic f-map --- test/update.jl | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 test/update.jl diff --git a/test/update.jl b/test/update.jl new file mode 100644 index 0000000..0ed6bca --- /dev/null +++ b/test/update.jl @@ -0,0 +1,23 @@ +@testset "Generalized fmap over equivalent functors" begin + struct M{F,T,S} + σ::F + W::T + b::S + end + + @functor M + + (m::M)(x) = m.σ.(m.W * x .+ m.b) + + m = M(identity, ones(Float32, 3, 4), zeros(Float32, 3)) + x = ones(Float32, 4, 2) + m̄, _ = gradient((m,x) -> sum(m(x)), m, x) + m̂ = Functors.fmap(m, m̄) do x, y + isnothing(x) && return y + isnothing(y) && return x + x .- 0.1f0 .* y + end + + @test m̂.W ≈ fill(0.8f0, size(m.W)) + @test m̂.b ≈ fill(-0.2f0, size(m.b)) +end