Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.6"
version = "0.23.7"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.5.3"
AbstractPPL = "0.6"
BangBang = "0.3"
Bijectors = "0.13"
ChainRulesCore = "0.9.7, 0.10, 1"
Expand Down
42 changes: 37 additions & 5 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,34 @@
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)

"""
need_concretize(expr)
Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or
requires a dynamic lens.
# Examples
```jldoctest; setup=:(using Setfield)
julia> DynamicPPL.need_concretize(:(x[1, :]))
true
julia> DynamicPPL.need_concretize(:(x[1, end]))
true
julia> DynamicPPL.need_concretize(:(x[1, 1]))
false
"""
function need_concretize(expr)
return Setfield.need_dynamic_lens(expr) || begin
flag = false
MacroTools.postwalk(expr) do ex
# Concretise colon by default
ex == :(:) && (flag = true) && return ex
end
flag
end
end

"""
isassumption(expr[, vn])
Expand All @@ -16,10 +45,13 @@ When `expr` is not an expression or symbol (i.e., a literal), this expands to `f
If `vn` is specified, it will be assumed to refer to a expression which
evaluates to a `VarName`, and this will be used in the subsequent checks.
If `vn` is not specified, `AbstractPPL.drop_escape(varname(expr))` will be
If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
used in its place.
"""
function isassumption(expr::Union{Expr,Symbol}, vn=AbstractPPL.drop_escape(varname(expr)))
function isassumption(
expr::Union{Expr,Symbol},
vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))),
)
return quote
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
# Considered an assumption by `__context__` which means either:
Expand Down Expand Up @@ -194,7 +226,7 @@ function unwrap_right_left_vns(
# for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`,
# and we therefore add the `Colon()` below.
vns = map(axes(left, 2)) do i
return vn Setfield.IndexLens((Colon(), i))
return AbstractPPL.concretize(vn Setfield.IndexLens((Colon(), i)), left)
end
return unwrap_right_left_vns(right, left, vns)
end
Expand Down Expand Up @@ -372,7 +404,7 @@ function generate_tilde(left, right)
return quote
$dist = $right
$vn = $(DynamicPPL.resolve_varnames)(
$(AbstractPPL.drop_escape(varname(left))), $dist
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist
)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $(DynamicPPL.isfixed(left, vn))
Expand Down Expand Up @@ -433,7 +465,7 @@ function generate_dot_tilde(left, right)
@gensym vn isassumption value
return quote
$vn = $(DynamicPPL.resolve_varnames)(
$(AbstractPPL.drop_escape(varname(left))), $right
$(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right
)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $(DynamicPPL.isfixed(left, vn))
Expand Down
3 changes: 2 additions & 1 deletion src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,8 @@ function logprior_true_with_logabsdet_jacobian(
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)})
return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m)]
s = zeros(1, 2) # used for varname concretization only
return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)]
end

@model function demo_assume_matrix_dot_observe_matrix(
Expand Down
6 changes: 6 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,12 @@ function BangBang.possible(
return BangBang.implements(setindex!, C) &&
promote_type(eltype(C), eltype(T)) <: eltype(C)
end
function BangBang.possible(
::typeof(BangBang._setindex!), ::C, ::T, ::AbstractPPL.ConcretizedSlice, ::Integer
) where {C<:AbstractMatrix,T<:AbstractVector}
return BangBang.implements(setindex!, C) &&
promote_type(eltype(C), eltype(T)) <: eltype(C)
end

# HACK(torfjelde): This makes it so it works on iterators, etc. by default.
# TODO(torfjelde): Do better.
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "2.1, 3.0, 4"
AbstractPPL = "0.5"
AbstractPPL = "0.6"
Bijectors = "0.13"
Distributions = "0.25"
DistributionsAD = "0.6.3"
Expand Down
2 changes: 1 addition & 1 deletion test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
@test inspace(@varname(z[1][:]), space)
@test inspace(@varname(z[1][2:3:10]), space)
@test inspace(@varname(M[[2, 3], 1]), space)
@test inspace(@varname(M[:, 1:4]), space)
@test_throws ErrorException inspace(@varname(M[:, 1:4]), space)
@test inspace(@varname(M[1, [2, 4, 6]]), space)
@test !inspace(@varname(z[2]), space)
@test !inspace(@varname(z), space)
Expand Down