diff --git a/src/Setfield.jl b/src/Setfield.jl index 62e5f1e..576f9cf 100644 --- a/src/Setfield.jl +++ b/src/Setfield.jl @@ -1,7 +1,7 @@ __precompile__(true) module Setfield using MacroTools -using MacroTools: isstructdef, splitstructdef +using MacroTools: isstructdef, splitstructdef, postwalk include("lens.jl") include("sugar.jl") diff --git a/src/lens.jl b/src/lens.jl index 17d5968..760169b 100644 --- a/src/lens.jl +++ b/src/lens.jl @@ -278,6 +278,15 @@ Base.@propagate_inbounds set(obj, ::ConstIndexLens{I}, val) where I = end end +struct DynamicIndexLens{F} <: Lens + f::F +end + +Base.@propagate_inbounds get(obj, I::DynamicIndexLens) = obj[I.f(obj)...] + +Base.@propagate_inbounds set(obj, I::DynamicIndexLens, val) = + setindex(obj, val, I.f(obj)...) + """ FunctionLens(f) @lens f(_) diff --git a/src/sugar.jl b/src/sugar.jl index dee7bb6..d735047 100644 --- a/src/sugar.jl +++ b/src/sugar.jl @@ -52,6 +52,30 @@ end is_interpolation(x) = x isa Expr && x.head == :$ +foldtree(op, init, x) = op(init, x) +foldtree(op, init, ex::Expr) = + op(foldl((acc, x) -> foldtree(op, acc, x), ex.args; init=init), ex) + +need_dynamic_lens(ex) = + foldtree(false, ex) do yes, x + yes || x === :end || x === :_ + end + +replace_underscore(ex, to) = postwalk(x -> x === :_ ? to : x, ex) + +function lower_index(collection::Symbol, index, dim) + if isexpr(index, :call) + return Expr(:call, lower_index.(collection, index.args, dim)...) + elseif index === :end + if dim === nothing + return :($(Base.lastindex)($collection)) + else + return :($(Base.lastindex)($collection, $dim)) + end + end + return index +end + function parse_obj_lenses(ex) if @capture(ex, front_[indices__]) obj, frontlens = parse_obj_lenses(front) @@ -63,6 +87,12 @@ function parse_obj_lenses(ex) end index = esc(Expr(:tuple, [x.args[1] for x in indices]...)) lens = :(ConstIndexLens{$index}()) + elseif any(need_dynamic_lens, indices) + @gensym collection + indices = replace_underscore.(indices, collection) + dims = length(indices) == 1 ? nothing : 1:length(indices) + lindices = esc.(lower_index.(collection, indices, dims)) + lens = :(DynamicIndexLens($(esc(collection)) -> ($(lindices...),))) else index = esc(Expr(:tuple, indices...)) lens = :(IndexLens($index)) diff --git a/test/test_core.jl b/test/test_core.jl index 7f54dfc..6c2c532 100644 --- a/test/test_core.jl +++ b/test/test_core.jl @@ -104,6 +104,10 @@ end i = 1 si = @set t.a[i] = 10 @test s1 === si + se = @set t.a[end] = 20 + @test se === T((1,20),(3,4)) + se1 = @set t.a[end-1] = 10 + @test s1 === se1 s1 = @set t.a[$1] = 10 @test s1 === T((10,2),(3,4)) @@ -191,6 +195,8 @@ end @lens _.b.a.b[i] @lens _.b.a.b[$2] @lens _.b.a.b[$i] + @lens _.b.a.b[end] + @lens _.b.a.b[identity(end) - 1] @lens _ ] val1, val2 = randn(2) @@ -226,6 +232,8 @@ end ((@lens _.b.a.b[$(i+1)]), 4 ), ((@lens _.b.a.b[$2] ), 4.0), ((@lens _.b.a.b[$(i+1)]), 4.0), + ((@lens _.b.a.b[end]), 4.0), + ((@lens _.b.a.b[end÷2+1]), 4.0), ((@lens _ ), obj), ((@lens _ ), :xy), (MultiPropertyLens((a=(@lens _), b=(@lens _))), (a=1, b=2)), @@ -238,25 +246,51 @@ end @testset "IndexLens" begin l = @lens _[] + @test l isa Setfield.IndexLens x = randn() obj = Ref(x) @test get(obj, l) == x l = @lens _[][] + @test l.outer isa Setfield.IndexLens + @test l.inner isa Setfield.IndexLens inner = Ref(x) obj = Base.RefValue{typeof(inner)}(inner) @test get(obj, l) == x obj = (1,2,3) l = @lens _[1] + @test l isa Setfield.IndexLens @test get(obj, l) == 1 @test set(obj, l, 6) == (6,2,3) l = @lens _[1:3] + @test l isa Setfield.IndexLens @test get([4,5,6,7], l) == [4,5,6] end +@testset "DynamicIndexLens" begin + l = @lens _[end] + @test l isa Setfield.DynamicIndexLens + obj = (1,2,3) + @test get(obj, l) == 3 + @test set(obj, l, true) == (1,2,true) + + l = @lens _[end÷2] + @test l isa Setfield.DynamicIndexLens + obj = (1,2,3) + @test get(obj, l) == 1 + @test set(obj, l, true) == (true,2,3) + + two = 2 + plusone(x) = x + 1 + l = @lens _.a[plusone(end) - two].b + obj = (a=(1, (a=10, b=20), 3), b=4) + @test get(obj, l) == 20 + @test set(obj, l, true) == (a=(1, (a=10, b=true), 3), b=4) +end + @testset "ConstIndexLens" begin obj = (1, 2.0, '3') l = @lens _[$1] diff --git a/test/test_staticarrays.jl b/test/test_staticarrays.jl index 4ec707c..ea9d3e7 100644 --- a/test/test_staticarrays.jl +++ b/test/test_staticarrays.jl @@ -17,5 +17,25 @@ using StaticArrays v = @SVector [1,2,3] @test (@set v[1] = 10) === @SVector [10,2,3] @test_broken (@set v[1] = π) === @SVector [π,2,3] + + @testset "Multi-dynamic indexing" begin + two = 2 + plusone(x) = x + 1 + l1 = @lens _.a[2, 1].b + l2 = @lens _.a[plusone(end) - two, end÷2].b + m_orig = @SMatrix [ + (a=1, b=10) (a=2, b=20) + (a=3, b=30) (a=4, b=40) + (a=5, b=50) (a=6, b=60) + ] + m_mod = @SMatrix [ + (a=1, b=10) (a=2, b=20) + (a=3, b=3000) (a=4, b=40) + (a=5, b=50) (a=6, b=60) + ] + obj = (a=m_orig, b=4) + @test get(obj, l1) === get(obj, l2) === 30 + @test set(obj, l1, 3000) === set(obj, l2, 3000) === (a=m_mod, b=4) + end end end