diff --git a/Project.toml b/Project.toml index 3251939c3..6b8abb913 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.6" +version = "0.23.7" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0, 4" -AbstractPPL = "0.5.3" +AbstractPPL = "0.6" BangBang = "0.3" Bijectors = "0.13" ChainRulesCore = "0.9.7, 0.10, 1" diff --git a/src/compiler.jl b/src/compiler.jl index ffdcd4755..96e98938b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,5 +1,34 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) +""" + need_concretize(expr) + +Return `true` if `expr` needs to be concretized, i.e., if it contains a colon `:` or +requires a dynamic lens. + +# Examples + +```jldoctest; setup=:(using Setfield) +julia> DynamicPPL.need_concretize(:(x[1, :])) +true + +julia> DynamicPPL.need_concretize(:(x[1, end])) +true + +julia> DynamicPPL.need_concretize(:(x[1, 1])) +false +""" +function need_concretize(expr) + return Setfield.need_dynamic_lens(expr) || begin + flag = false + MacroTools.postwalk(expr) do ex + # Concretise colon by default + ex == :(:) && (flag = true) && return ex + end + flag + end +end + """ isassumption(expr[, vn]) @@ -16,10 +45,13 @@ When `expr` is not an expression or symbol (i.e., a literal), this expands to `f If `vn` is specified, it will be assumed to refer to a expression which evaluates to a `VarName`, and this will be used in the subsequent checks. -If `vn` is not specified, `AbstractPPL.drop_escape(varname(expr))` will be +If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be used in its place. """ -function isassumption(expr::Union{Expr,Symbol}, vn=AbstractPPL.drop_escape(varname(expr))) +function isassumption( + expr::Union{Expr,Symbol}, + vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), +) return quote if $(DynamicPPL.contextual_isassumption)(__context__, $vn) # Considered an assumption by `__context__` which means either: @@ -194,7 +226,7 @@ function unwrap_right_left_vns( # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, # and we therefore add the `Colon()` below. vns = map(axes(left, 2)) do i - return vn ∘ Setfield.IndexLens((Colon(), i)) + return AbstractPPL.concretize(vn ∘ Setfield.IndexLens((Colon(), i)), left) end return unwrap_right_left_vns(right, left, vns) end @@ -372,7 +404,7 @@ function generate_tilde(left, right) return quote $dist = $right $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left))), $dist + $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist ) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) @@ -433,7 +465,7 @@ function generate_dot_tilde(left, right) @gensym vn isassumption value return quote $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left))), $right + $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $right ) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) diff --git a/src/test_utils.jl b/src/test_utils.jl index 394341c04..17898b9fe 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -543,7 +543,8 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m)] + s = zeros(1, 2) # used for varname concretization only + return [@varname(s[:, 1], true), @varname(s[:, 2], true), @varname(m)] end @model function demo_assume_matrix_dot_observe_matrix( diff --git a/src/utils.jl b/src/utils.jl index abd7b9da0..5c129f4a4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -500,6 +500,12 @@ function BangBang.possible( return BangBang.implements(setindex!, C) && promote_type(eltype(C), eltype(T)) <: eltype(C) end +function BangBang.possible( + ::typeof(BangBang._setindex!), ::C, ::T, ::AbstractPPL.ConcretizedSlice, ::Integer +) where {C<:AbstractMatrix,T<:AbstractVector} + return BangBang.implements(setindex!, C) && + promote_type(eltype(C), eltype(T)) <: eltype(C) +end # HACK(torfjelde): This makes it so it works on iterators, etc. by default. # TODO(torfjelde): Do better. diff --git a/test/Project.toml b/test/Project.toml index 1f3e6e4fe..b36a7e23a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "2.1, 3.0, 4" -AbstractPPL = "0.5" +AbstractPPL = "0.6" Bijectors = "0.13" Distributions = "0.25" DistributionsAD = "0.6.3" diff --git a/test/varinfo.jl b/test/varinfo.jl index de86a9e05..35ab30dcd 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -55,7 +55,7 @@ @test inspace(@varname(z[1][:]), space) @test inspace(@varname(z[1][2:3:10]), space) @test inspace(@varname(M[[2, 3], 1]), space) - @test inspace(@varname(M[:, 1:4]), space) + @test_throws ErrorException inspace(@varname(M[:, 1:4]), space) @test inspace(@varname(M[1, [2, 4, 6]]), space) @test !inspace(@varname(z[2]), space) @test !inspace(@varname(z), space)