diff --git a/Project.toml b/Project.toml index ed5ef88a..bfe867a7 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.2.0" +version = "0.2.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/varname.jl b/src/varname.jl index 92d17133..f53909c3 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -241,8 +241,18 @@ julia> @varname(x[:, 1][2]).indexing julia> @varname(x[1,2][1+5][45][3]).indexing ((1, 2), (6,), (45,), (3,)) + +julia> let a = [42]; @varname(a[1][end][3]); end +a[1][1][3] ``` +!!! warning + As you can see in the last example, `@varname` does not do any bounds checking! + + Only `begin` and `end` indexing, which depend on the _runtime size_ of the array, does that! + Their usage _requires_ that the array over which the indexing expression is defined, in + order for `firstindex` and `lastindex` to work in the expanded code. + !!! compat "Julia 1.5" Using `begin` in an indexing expression to refer to the first index requires at least Julia 1.5. @@ -254,8 +264,9 @@ end varname(sym::Symbol) = :($(AbstractPPL.VarName){$(QuoteNode(sym))}()) function varname(expr::Expr) if Meta.isexpr(expr, :ref) - sym, inds = vsym(expr), vinds(expr) - return :($(AbstractPPL.VarName){$(QuoteNode(sym))}($inds)) + head = vsym(expr) + inds = vinds(expr, head) + return :($(AbstractPPL.VarName){$(QuoteNode(head))}($inds)) else error("Malformed variable name $(expr)!") end @@ -334,37 +345,128 @@ macro vinds(expr::Union{Expr, Symbol}) end +@static if VERSION < v"1.5.0-DEV.666" + _replace_ref_begin_end(ex, withex) = Base.replace_ref_end_!(copy(ex), withex) + _replace_ref_begin_end(ex::Symbol, withex) = Base.replace_ref_end_!(ex, withex) + _index_replacement_for(s) = :($lastindex($s)) + _index_replacement_for(s, n) = :($lastindex($s, $n)) +else + _replace_ref_begin_end(ex, withex) = Base.replace_ref_begin_end_!(copy(ex), withex) + _replace_ref_begin_end(ex::Symbol, withex) = Base.replace_ref_begin_end_!(ex, withex) + _index_replacement_for(s) = :($firstindex($s)), :($lastindex($s)) + _index_replacement_for(s, n) = :($firstindex($s, $n)), :($lastindex($s, $n)) +end + + """ vinds(expr) Return the indexing part of the [`@varname`](@ref)-compatible expression `expr` as an expression -suitable for input of the [`VarName`](@ref) constructor. +suitable for input of the [`VarName`](@ref) constructor (i.e., a tuple of tuples). ## Examples ```jldoctest -julia> vinds(:(x[end])) -:((((lastindex)(x),),)) +julia> x = [10, 20, 30]; eval(vinds(:(x[end]))) +((3,),) + + +julia> x = [10 20]; eval(vinds(:(x[1, end]))) +((1, 2),) + +julia> x = [[1, 2]]; eval(vinds(:(x[1][end]))) +((1,), (2,)) + +julia> x = ([1, 2], ); eval(vinds(:(x[1][end]))) # tuple +((1,), (2,)) + +julia> x = [fill([[10], [20, 30]], 2, 2, 2)] + if VERSION < v"1.5.0-DEV.666" + eval(vinds(:(x[1][2, end, :][2][end]))) + else + eval(vinds(Meta.parse("x[begin][2, end, :][2][end]"))) + end +((1,), (2, 2, Colon()), (2,), (2,)) -julia> vinds(:(x[1, end])) -:(((1, (lastindex)(x, 2)),)) ``` """ -function vinds end +function vinds(expr, head = vsym(expr)) + # see https://github.com/JuliaLang/julia/blob/bb5b98e72a151c41471d8cc14cacb495d647fb7f/base/views.jl#L17-L75 + indexing = _straighten_indexing(expr) + inds = Expr[] # collection of result indices + partial = head # partial :ref expressions, used in caching + cached_exprs = Vector{Pair{Symbol, Expr}}() # cache for partial expressions going into a let + + for ixs in indexing + # S becomes the name of the cached variable + S = (partial == head) ? head : gensym(:S) + used_S = false + + nixs = length(ixs) + if nixs == 1 + # for linear indexing, we use `lastindex(x)` + ixs[1], used = _replace_ref_begin_end(ixs[1], _index_replacement_for(S)) + used_S |= used + elseif nixs > 1 + # for cartesian indexing, we need `lastindex(x, i)` + for i in eachindex(ixs) + ixs[i], used = _replace_ref_begin_end(ixs[i], _index_replacement_for(S, i)) + used_S |= used + end + end -vinds(expr::Symbol) = Expr(:tuple) -function vinds(expr::Expr) - if Meta.isexpr(expr, :ref) - ex = copy(expr) - @static if VERSION < v"1.5.0-DEV.666" - Base.replace_ref_end!(ex) + if used_S && partial !== head + # cache that expression if we actually used it, and use the new name in the + # partial expression + push!(cached_exprs, S => partial) + partial = Expr(:call, Base.maybeview, S, ixs...) else - Base.replace_ref_begin_end!(ex) + partial = Expr(:call, Base.maybeview, partial, ixs...) end - last = Expr(:tuple, ex.args[2:end]...) - init = vinds(ex.args[1]).args - return Expr(:tuple, init..., last) + + push!(inds, Expr(:tuple, ixs...)) + end + + # finally make the tuple of tuples + tuple_expr = Expr(:tuple, inds...) + + if length(cached_exprs) == 0 + return tuple_expr + else + # construct one big let expression + cached_assignments = [:($S = $partial) for (S, partial) in cached_exprs] + return Expr(:let, Expr(:block, cached_assignments...), tuple_expr) + end +end + + +""" + _straighten_indexing(expr) + +Extract a list of lists of (raw) indices of an iterated `:ref` expression. + +```julia +julia> _straighten_indexing(:(x[begin][2, end, :][2][end])) +4-element Array{Array{Any,1},1}: + [:begin] + [2, :end, :(:)] + [2] + [:end] +``` +""" +_straighten_indexing(expr::Symbol) = Vector{Any}[] +function _straighten_indexing(expr::Expr) + if Meta.isexpr(expr, :ref) + init = _straighten_indexing(expr.args[1]) + last = expr.args[2:end] + return push!(init, last) else error("Mis-formed variable name $(expr)!") end end + + + + + +