Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
keywords = ["probablistic programming"]
license = "MIT"
desc = "Common interfaces for probabilistic programming"
version = "0.2.0"
version = "0.2.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
138 changes: 120 additions & 18 deletions src/varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,18 @@ julia> @varname(x[:, 1][2]).indexing

julia> @varname(x[1,2][1+5][45][3]).indexing
((1, 2), (6,), (45,), (3,))

julia> let a = [42]; @varname(a[1][end][3]); end
a[1][1][3]
```

!!! warning
As you can see in the last example, `@varname` does not do any bounds checking!

Only `begin` and `end` indexing, which depend on the _runtime size_ of the array, does that!
Their usage _requires_ that the array over which the indexing expression is defined, in
order for `firstindex` and `lastindex` to work in the expanded code.

!!! compat "Julia 1.5"
Using `begin` in an indexing expression to refer to the first index requires at least
Julia 1.5.
Expand All @@ -254,8 +264,9 @@ end
varname(sym::Symbol) = :($(AbstractPPL.VarName){$(QuoteNode(sym))}())
function varname(expr::Expr)
if Meta.isexpr(expr, :ref)
sym, inds = vsym(expr), vinds(expr)
return :($(AbstractPPL.VarName){$(QuoteNode(sym))}($inds))
head = vsym(expr)
inds = vinds(expr, head)
return :($(AbstractPPL.VarName){$(QuoteNode(head))}($inds))
else
error("Malformed variable name $(expr)!")
end
Expand Down Expand Up @@ -334,37 +345,128 @@ macro vinds(expr::Union{Expr, Symbol})
end


@static if VERSION < v"1.5.0-DEV.666"
_replace_ref_begin_end(ex, withex) = Base.replace_ref_end_!(copy(ex), withex)
_replace_ref_begin_end(ex::Symbol, withex) = Base.replace_ref_end_!(ex, withex)
_index_replacement_for(s) = :($lastindex($s))
_index_replacement_for(s, n) = :($lastindex($s, $n))
else
_replace_ref_begin_end(ex, withex) = Base.replace_ref_begin_end_!(copy(ex), withex)
_replace_ref_begin_end(ex::Symbol, withex) = Base.replace_ref_begin_end_!(ex, withex)
_index_replacement_for(s) = :($firstindex($s)), :($lastindex($s))
_index_replacement_for(s, n) = :($firstindex($s, $n)), :($lastindex($s, $n))
end


"""
vinds(expr)

Return the indexing part of the [`@varname`](@ref)-compatible expression `expr` as an expression
suitable for input of the [`VarName`](@ref) constructor.
suitable for input of the [`VarName`](@ref) constructor (i.e., a tuple of tuples).

## Examples

```jldoctest
julia> vinds(:(x[end]))
:((((lastindex)(x),),))
julia> x = [10, 20, 30]; eval(vinds(:(x[end])))
((3,),)


julia> x = [10 20]; eval(vinds(:(x[1, end])))
((1, 2),)

julia> x = [[1, 2]]; eval(vinds(:(x[1][end])))
((1,), (2,))

julia> x = ([1, 2], ); eval(vinds(:(x[1][end]))) # tuple
((1,), (2,))

julia> x = [fill([[10], [20, 30]], 2, 2, 2)]
if VERSION < v"1.5.0-DEV.666"
eval(vinds(:(x[1][2, end, :][2][end])))
else
eval(vinds(Meta.parse("x[begin][2, end, :][2][end]")))
end
((1,), (2, 2, Colon()), (2,), (2,))

julia> vinds(:(x[1, end]))
:(((1, (lastindex)(x, 2)),))
```
"""
function vinds end
function vinds(expr, head = vsym(expr))
# see https://github.com/JuliaLang/julia/blob/bb5b98e72a151c41471d8cc14cacb495d647fb7f/base/views.jl#L17-L75
indexing = _straighten_indexing(expr)
inds = Expr[] # collection of result indices
partial = head # partial :ref expressions, used in caching
cached_exprs = Vector{Pair{Symbol, Expr}}() # cache for partial expressions going into a let

for ixs in indexing
# S becomes the name of the cached variable
S = (partial == head) ? head : gensym(:S)
used_S = false

nixs = length(ixs)
if nixs == 1
# for linear indexing, we use `lastindex(x)`
ixs[1], used = _replace_ref_begin_end(ixs[1], _index_replacement_for(S))
used_S |= used
elseif nixs > 1
# for cartesian indexing, we need `lastindex(x, i)`
for i in eachindex(ixs)
ixs[i], used = _replace_ref_begin_end(ixs[i], _index_replacement_for(S, i))
used_S |= used
end
end

vinds(expr::Symbol) = Expr(:tuple)
function vinds(expr::Expr)
if Meta.isexpr(expr, :ref)
ex = copy(expr)
@static if VERSION < v"1.5.0-DEV.666"
Base.replace_ref_end!(ex)
if used_S && partial !== head
# cache that expression if we actually used it, and use the new name in the
# partial expression
push!(cached_exprs, S => partial)
partial = Expr(:call, Base.maybeview, S, ixs...)
else
Base.replace_ref_begin_end!(ex)
partial = Expr(:call, Base.maybeview, partial, ixs...)
end
last = Expr(:tuple, ex.args[2:end]...)
init = vinds(ex.args[1]).args
return Expr(:tuple, init..., last)

push!(inds, Expr(:tuple, ixs...))
end

# finally make the tuple of tuples
tuple_expr = Expr(:tuple, inds...)

if length(cached_exprs) == 0
return tuple_expr
else
# construct one big let expression
cached_assignments = [:($S = $partial) for (S, partial) in cached_exprs]
return Expr(:let, Expr(:block, cached_assignments...), tuple_expr)
end
end


"""
_straighten_indexing(expr)

Extract a list of lists of (raw) indices of an iterated `:ref` expression.

```julia
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
```julia
```jldoctest

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that work with a non-exported function?

julia> _straighten_indexing(:(x[begin][2, end, :][2][end]))
4-element Array{Array{Any,1},1}:
[:begin]
[2, :end, :(:)]
[2]
[:end]
```
"""
_straighten_indexing(expr::Symbol) = Vector{Any}[]
function _straighten_indexing(expr::Expr)
if Meta.isexpr(expr, :ref)
init = _straighten_indexing(expr.args[1])
last = expr.args[2:end]
return push!(init, last)
else
error("Mis-formed variable name $(expr)!")
end
end