Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Functors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module Functors

using MacroTools

export @functor, fmap, fcollect
export @functor, fmap, fmapstructure, fcollect

include("functor.jl")

Expand Down
73 changes: 64 additions & 9 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function makefunctor(m::Module, T, fs = fieldnames(T))
f in fs ? :(y[$(yᵢ += 1)]) : :(x.$f)
end
escfs = [:($f=x.$f) for f in fs]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple of no-op line changes. Do you know what that's from?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, my editor may have trimmed existing trailing whitespace automatically on save?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, so we'll keep these

@eval m begin
$Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...))
end
Expand Down Expand Up @@ -44,21 +44,76 @@ Equivalent to `functor(x)[1]`.
"""
children(x) = functor(x)[1]

function fmap1(f, x)
function _default_walk(f, x)
func, re = functor(x)
re(map(f, func))
end

# See https://github.com/FluxML/Functors.jl/issues/2 for a discussion regarding the need for
# cache.
function fmap(f, x; exclude = isleaf, cache = IdDict())
"""
fmap(f, x; exclude = isleaf, walk = Functors._default_walk)

A structure and type preserving `map` that works for all [`functor`](@ref)s.

By default, traverses `x` recursively using [`functor`](@ref)
and transforms every leaf node identified by `exclude` with `f`.

For advanced customization of the traversal behaviour, pass a custom `walk` function of the form `(f', xs) -> ...`.
This function walks (maps) over `xs` calling the continuation `f'` to continue traversal.

# Examples
```jldoctest
julia> struct Foo; x; y; end

julia> @functor Foo

julia> struct Bar; x; end

julia> @functor Bar

julia> m = Foo(Bar([1,2,3]), (4, 5));

julia> fmap(x -> 2x, m)
Foo(Bar([2, 4, 6]), (8, 10))

julia> fmap(string, m)
Foo(Bar("[1, 2, 3]"), ("4", "5"))

julia> fmap(string, m, exclude = v -> v isa Bar)
Foo("Bar([1, 2, 3])", (4, 5))

julia> fmap(x -> 2x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x))
Foo(Bar([1, 2, 3]), (8, 10))
```
"""
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())
haskey(cache, x) && return cache[x]
y = exclude(x) ? f(x) : fmap1(x -> fmap(f, x, cache = cache, exclude = exclude), x)
y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x)
cache[x] = y

return y
end

"""
fmapstructure(f, x; exclude = isleaf)

Like [`fmap`](@ref), but doesn't preserve the type of custom structs. Instead, it returns a (potentially nested) `NamedTuple`.

Useful for when the output must not contain custom structs.

# Examples
```jldoctest
julia> struct Foo; x; y; end

julia> @functor Foo

julia> m = Foo([1,2,3], (4, 5));

julia> fmapstructure(x -> 2x, m)
(x = [2, 4, 6], y = (8, 10))
```
"""
fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)), kwargs...)

"""
fcollect(x; exclude = v -> false)

Expand All @@ -68,7 +123,7 @@ and collecting the results into a flat array.
Doesn't recurse inside branches rooted at nodes `v`
for which `exclude(v) == true`.
In such cases, the root `v` is also excluded from the result.
By default, `exclude` always yields `false`.
By default, `exclude` always yields `false`.

See also [`children`](@ref).

Expand All @@ -83,7 +138,7 @@ julia> struct Bar; x; end

julia> @functor Bar

julia> struct NoChildren; x; y; end
julia> struct NoChildren; x; y; end

julia> m = Foo(Bar([1,2,3]), NoChildren(:a, :b))

Expand All @@ -98,7 +153,7 @@ julia> fcollect(m, exclude = v -> v isa Bar)
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
NoChildren(:a, :b)

julia> fcollect(m, exclude = v -> Functors.isleaf(v))
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
Expand Down
7 changes: 7 additions & 0 deletions test/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ end
@test fmap(f, x; exclude = x -> x isa AbstractArray) == x
end

@testset "Walk" begin
model = Foo((0, Bar([1, 2, 3])), [4, 5])

model′ = fmapstructure(identity, model)
@test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5])
end

@testset "Property list" begin
model = Baz(1, 2, 3)
model′ = fmap(x -> 2x, model)
Expand Down