From 543a3d2a8ec10f974d13ba1caf1a13501cd699f0 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Thu, 1 Jul 2021 13:29:48 -0700 Subject: [PATCH 1/4] Add `walk` kwarg to `fmap` and `fmapstructure` --- src/Functors.jl | 2 +- src/functor.jl | 60 +++++++++++++++++++++++++++++++++++++++++++++---- test/basics.jl | 7 ++++++ 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/src/Functors.jl b/src/Functors.jl index 65fb494..fa3ce9f 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -2,7 +2,7 @@ module Functors using MacroTools -export @functor, fmap, fcollect +export @functor, fmap, fmapstructure, fcollect include("functor.jl") diff --git a/src/functor.jl b/src/functor.jl index f351108..b62fa83 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -49,16 +49,68 @@ function fmap1(f, x) re(map(f, func)) end -# See https://github.com/FluxML/Functors.jl/issues/2 for a discussion regarding the need for -# cache. -function fmap(f, x; exclude = isleaf, cache = IdDict()) +""" + fmap(f, x; exclude = isleaf, walk = (f′, x) -> ...) + +A structure and type preserving `map` that works for all [`functor`](@ref)s. + +By default, traveres `x` recursively using [`functor`](@ref) +and transforms every leaf node identified by `exclude` with `f`. + +For advanced customization of the traversal behaviour, pass a custom `walk` function. +This function must itself call the continuation f′ to continue traversal. + +# Examples +```jldoctest +julia> struct Foo; x; y; end + +julia> @functor Foo + +julia> struct Bar; x; end + +julia> @functor Bar + +julia> m = Foo(Bar([1,2,3]), (4, 5)); + +julia> fmap(x -> 2x, m) +Foo(Bar([2, 4, 6]), (8, 10)) + +julia> fmap(string, m) +Foo(Bar("[1, 2, 3]"), ("4", "5")) + +julia> fmap(string, m, exclude = v -> v isa Bar) +Foo("Bar([1, 2, 3])", (4, 5)) +``` +""" +function fmap(f, x; exclude = isleaf, walk = fmap1, cache = IdDict()) haskey(cache, x) && return cache[x] - y = exclude(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache, exclude = exclude), x) + y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x) cache[x] = y return y end +""" + fmapstructure(f, x; exclude = isleaf) + +Like [`fmap`](@ref), but doesn't preserve the type of custom structs. + +Useful for when the output must be plain-old julia data structures. + +# Examples +```jldoctest +julia> struct Foo; x; y; end + +julia> @functor Foo + +julia> m = Foo([1,2,3], (4, 5)); + +julia> fmapstructure(x -> 2x, m) +(x = [2, 4, 6], y = (8, 10)) +``` +""" +fmapstructure(f, x; kwargs...) = fmap(f, x; walk=(f, x) -> map(f, children(x)), kwargs...) + """ fcollect(x; exclude = v -> false) diff --git a/test/basics.jl b/test/basics.jl index 9476e4e..7f28513 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -45,6 +45,13 @@ end @test fmap(f, x; exclude = x -> x isa AbstractArray) == x end +@testset "Walk" begin + model = Foo((0, Bar([1, 2, 3])), [4, 5]) + + model′ = fmapstructure(identity, model) + @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) +end + @testset "Property list" begin model = Baz(1, 2, 3) model′ = fmap(x -> 2x, model) From 8ac77d844c51f4dc5de6d30c9b6d4227c8306d77 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 5 Jul 2021 15:57:03 -0700 Subject: [PATCH 2/4] s/fmap1/_default_walk --- src/functor.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index b62fa83..1e57291 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -13,7 +13,7 @@ function makefunctor(m::Module, T, fs = fieldnames(T)) f in fs ? :(y[$(yᵢ += 1)]) : :(x.$f) end escfs = [:($f=x.$f) for f in fs] - + @eval m begin $Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...)) end @@ -44,14 +44,14 @@ Equivalent to `functor(x)[1]`. """ children(x) = functor(x)[1] -function fmap1(f, x) +function _default_walk(f, x) func, re = functor(x) re(map(f, func)) end """ - fmap(f, x; exclude = isleaf, walk = (f′, x) -> ...) - + fmap(f, x; exclude = isleaf, walk = (f′, x) -> ...) + A structure and type preserving `map` that works for all [`functor`](@ref)s. By default, traveres `x` recursively using [`functor`](@ref) @@ -82,7 +82,7 @@ julia> fmap(string, m, exclude = v -> v isa Bar) Foo("Bar([1, 2, 3])", (4, 5)) ``` """ -function fmap(f, x; exclude = isleaf, walk = fmap1, cache = IdDict()) +function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict()) haskey(cache, x) && return cache[x] y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x) cache[x] = y @@ -92,7 +92,7 @@ end """ fmapstructure(f, x; exclude = isleaf) - + Like [`fmap`](@ref), but doesn't preserve the type of custom structs. Useful for when the output must be plain-old julia data structures. @@ -120,7 +120,7 @@ and collecting the results into a flat array. Doesn't recurse inside branches rooted at nodes `v` for which `exclude(v) == true`. In such cases, the root `v` is also excluded from the result. -By default, `exclude` always yields `false`. +By default, `exclude` always yields `false`. See also [`children`](@ref). @@ -135,7 +135,7 @@ julia> struct Bar; x; end julia> @functor Bar -julia> struct NoChildren; x; y; end +julia> struct NoChildren; x; y; end julia> m = Foo(Bar([1,2,3]), NoChildren(:a, :b)) @@ -150,7 +150,7 @@ julia> fcollect(m, exclude = v -> v isa Bar) 2-element Vector{Any}: Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) NoChildren(:a, :b) - + julia> fcollect(m, exclude = v -> Functors.isleaf(v)) 2-element Vector{Any}: Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) From 2852ce2ee12cae8eb3406affac048c51cffbd8df Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 5 Jul 2021 21:52:15 -0700 Subject: [PATCH 3/4] Docstring and formatting suggesions Co-authored-by: Kyle Daruwalla --- src/functor.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 1e57291..c76d210 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -50,15 +50,15 @@ function _default_walk(f, x) end """ - fmap(f, x; exclude = isleaf, walk = (f′, x) -> ...) + fmap(f, x; exclude = isleaf, walk = Functors._default_walk) A structure and type preserving `map` that works for all [`functor`](@ref)s. By default, traveres `x` recursively using [`functor`](@ref) and transforms every leaf node identified by `exclude` with `f`. -For advanced customization of the traversal behaviour, pass a custom `walk` function. -This function must itself call the continuation f′ to continue traversal. +For advanced customization of the traversal behaviour, pass a custom `walk` function of the form `(f', xs) -> ...`. +This function walks (maps) over `xs` calling the continuation `f'` to continue traversal. # Examples ```jldoctest @@ -93,9 +93,9 @@ end """ fmapstructure(f, x; exclude = isleaf) -Like [`fmap`](@ref), but doesn't preserve the type of custom structs. +Like [`fmap`](@ref), but doesn't preserve the type of custom structs. Instead, it returns a (potentially nested) `NamedTuple`. -Useful for when the output must be plain-old julia data structures. +Useful for when the output must not contain custom structs. # Examples ```jldoctest @@ -109,7 +109,7 @@ julia> fmapstructure(x -> 2x, m) (x = [2, 4, 6], y = (8, 10)) ``` """ -fmapstructure(f, x; kwargs...) = fmap(f, x; walk=(f, x) -> map(f, children(x)), kwargs...) +fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)), kwargs...) """ fcollect(x; exclude = v -> false) From f7030c36ff4ce5310d9ecb7a6aefc0e311f22ba0 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Thu, 8 Jul 2021 16:50:21 -0700 Subject: [PATCH 4/4] Typo and `walk` example --- src/functor.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/functor.jl b/src/functor.jl index c76d210..90eb0fa 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -54,7 +54,7 @@ end A structure and type preserving `map` that works for all [`functor`](@ref)s. -By default, traveres `x` recursively using [`functor`](@ref) +By default, traverses `x` recursively using [`functor`](@ref) and transforms every leaf node identified by `exclude` with `f`. For advanced customization of the traversal behaviour, pass a custom `walk` function of the form `(f', xs) -> ...`. @@ -80,6 +80,9 @@ Foo(Bar("[1, 2, 3]"), ("4", "5")) julia> fmap(string, m, exclude = v -> v isa Bar) Foo("Bar([1, 2, 3])", (4, 5)) + +julia> fmap(x -> 2x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x)) +Foo(Bar([1, 2, 3]), (8, 10)) ``` """ function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())