diff --git a/src/Functors.jl b/src/Functors.jl index 65fb494..fa3ce9f 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -2,7 +2,7 @@ module Functors using MacroTools -export @functor, fmap, fcollect +export @functor, fmap, fmapstructure, fcollect include("functor.jl") diff --git a/src/functor.jl b/src/functor.jl index f351108..90eb0fa 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -13,7 +13,7 @@ function makefunctor(m::Module, T, fs = fieldnames(T)) f in fs ? :(y[$(yᵢ += 1)]) : :(x.$f) end escfs = [:($f=x.$f) for f in fs] - + @eval m begin $Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...)) end @@ -44,21 +44,76 @@ Equivalent to `functor(x)[1]`. """ children(x) = functor(x)[1] -function fmap1(f, x) +function _default_walk(f, x) func, re = functor(x) re(map(f, func)) end -# See https://github.com/FluxML/Functors.jl/issues/2 for a discussion regarding the need for -# cache. -function fmap(f, x; exclude = isleaf, cache = IdDict()) +""" + fmap(f, x; exclude = isleaf, walk = Functors._default_walk) + +A structure and type preserving `map` that works for all [`functor`](@ref)s. + +By default, traverses `x` recursively using [`functor`](@ref) +and transforms every leaf node identified by `exclude` with `f`. + +For advanced customization of the traversal behaviour, pass a custom `walk` function of the form `(f', xs) -> ...`. +This function walks (maps) over `xs` calling the continuation `f'` to continue traversal. + +# Examples +```jldoctest +julia> struct Foo; x; y; end + +julia> @functor Foo + +julia> struct Bar; x; end + +julia> @functor Bar + +julia> m = Foo(Bar([1,2,3]), (4, 5)); + +julia> fmap(x -> 2x, m) +Foo(Bar([2, 4, 6]), (8, 10)) + +julia> fmap(string, m) +Foo(Bar("[1, 2, 3]"), ("4", "5")) + +julia> fmap(string, m, exclude = v -> v isa Bar) +Foo("Bar([1, 2, 3])", (4, 5)) + +julia> fmap(x -> 2x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x)) +Foo(Bar([1, 2, 3]), (8, 10)) +``` +""" +function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict()) haskey(cache, x) && return cache[x] - y = exclude(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache, exclude = exclude), x) + y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x) cache[x] = y return y end +""" + fmapstructure(f, x; exclude = isleaf) + +Like [`fmap`](@ref), but doesn't preserve the type of custom structs. Instead, it returns a (potentially nested) `NamedTuple`. + +Useful for when the output must not contain custom structs. + +# Examples +```jldoctest +julia> struct Foo; x; y; end + +julia> @functor Foo + +julia> m = Foo([1,2,3], (4, 5)); + +julia> fmapstructure(x -> 2x, m) +(x = [2, 4, 6], y = (8, 10)) +``` +""" +fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)), kwargs...) + """ fcollect(x; exclude = v -> false) @@ -68,7 +123,7 @@ and collecting the results into a flat array. Doesn't recurse inside branches rooted at nodes `v` for which `exclude(v) == true`. In such cases, the root `v` is also excluded from the result. -By default, `exclude` always yields `false`. +By default, `exclude` always yields `false`. See also [`children`](@ref). @@ -83,7 +138,7 @@ julia> struct Bar; x; end julia> @functor Bar -julia> struct NoChildren; x; y; end +julia> struct NoChildren; x; y; end julia> m = Foo(Bar([1,2,3]), NoChildren(:a, :b)) @@ -98,7 +153,7 @@ julia> fcollect(m, exclude = v -> v isa Bar) 2-element Vector{Any}: Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) NoChildren(:a, :b) - + julia> fcollect(m, exclude = v -> Functors.isleaf(v)) 2-element Vector{Any}: Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) diff --git a/test/basics.jl b/test/basics.jl index 9476e4e..7f28513 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -45,6 +45,13 @@ end @test fmap(f, x; exclude = x -> x isa AbstractArray) == x end +@testset "Walk" begin + model = Foo((0, Bar([1, 2, 3])), [4, 5]) + + model′ = fmapstructure(identity, model) + @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) +end + @testset "Property list" begin model = Baz(1, 2, 3) model′ = fmap(x -> 2x, model)