From a8e55bd0a6cb0798b980dbb36be9f3b263d3875d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 16:10:20 +0100 Subject: [PATCH 01/14] updated SimpleVarInfo impl --- src/DynamicPPL.jl | 1 + src/simple_varinfo.jl | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a26516283..88a5ad89f 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -138,6 +138,7 @@ include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") include("varinfo.jl") +include("simple_varinfo.jl") include("threadsafe.jl") include("context_implementations.jl") include("compiler.jl") diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 4e848e291..501aa2185 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -21,15 +21,15 @@ SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) getlogp(vi::SimpleVarInfo) = vi.logp -setlogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) -acclogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) +setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) +acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) -function setlogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) +function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp return vi end -function acclogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) +function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] += logp return vi end @@ -69,8 +69,8 @@ function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo{<:NamedTuple} return left, Distributions.loglikelihood(dist, left) end -# function dot_tilde_assume!(context, right, left, vn, inds, vi::SimpleVarInfo) -# throw(MethodError(dot_tilde_assume!, (context, right, left, vn, inds, vi))) +# function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) +# throw(MethodError(dot_tilde_assume!!, (context, right, left, vn, inds, vi))) # end function dot_assume( From bfd7c789639df0395e33e0c1bf20557c27f60aa1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 2 Jul 2021 03:42:50 +0100 Subject: [PATCH 02/14] added eltype impl for SimpleVarInfo --- src/simple_varinfo.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 501aa2185..d5ca2fc13 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -62,6 +62,11 @@ getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) # HACK: Need to disambiguiate. getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) +# Necessary for `matchingvalue` to work properly. +function Base.eltype(vi::SimpleVarInfo{<:Any, T}, spl::Union{AbstractSampler,SampleFromPrior}) + return T +end + # Context implementations # Only evaluation makes sense for `SimpleVarInfo`, so we only implement this. function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo{<:NamedTuple}) From acb15eb9b9525eda0b57036f2b6864623d8ab1d3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 2 Jul 2021 03:45:38 +0100 Subject: [PATCH 03/14] formatting --- src/simple_varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d5ca2fc13..12437844e 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -63,7 +63,9 @@ getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) # Necessary for `matchingvalue` to work properly. -function Base.eltype(vi::SimpleVarInfo{<:Any, T}, spl::Union{AbstractSampler,SampleFromPrior}) +function Base.eltype( + vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior} +) return T end From 4828aab2f3ee108f286b461a057f8909c9dfbc4e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 6 Jul 2021 10:56:52 +0100 Subject: [PATCH 04/14] fixed eltype for SimpleVarInfo --- src/simple_varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 12437844e..c88bf0192 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -65,7 +65,7 @@ getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) # Necessary for `matchingvalue` to work properly. function Base.eltype( vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior} -) +) where {T} return T end @@ -136,3 +136,5 @@ function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names return SimpleVarInfo{T}(NamedTuple{names}(vals)) end + +SimpleVarInfo(model::Model, args...) = SimpleVarInfo(VarInfo(Random.GLOBAL_RNG, model, args...)) From e4f0ad263a862b3157df20fc39e795e88c860a3f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Jul 2021 22:50:16 +0100 Subject: [PATCH 05/14] formatting --- src/simple_varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c88bf0192..147865cc0 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -137,4 +137,6 @@ function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names return SimpleVarInfo{T}(NamedTuple{names}(vals)) end -SimpleVarInfo(model::Model, args...) = SimpleVarInfo(VarInfo(Random.GLOBAL_RNG, model, args...)) +function SimpleVarInfo(model::Model, args...) + return SimpleVarInfo(VarInfo(Random.GLOBAL_RNG, model, args...)) +end From ccfd112d3b55d67e8c886bb435358d5b2b3f38dd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Jul 2021 22:50:58 +0100 Subject: [PATCH 06/14] initial work on allowing sampling using SimpleVarInfo --- Project.toml | 1 + src/simple_varinfo.jl | 56 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 7cc3db31d..cf8dfab90 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 147865cc0..859ef0540 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -1,3 +1,5 @@ +using Setfield + """ SimpleVarInfo{NT,T} <: AbstractVarInfo @@ -19,6 +21,8 @@ end SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) +SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(nothing) +SimpleVarInfo() = SimpleVarInfo{Float64}() getlogp(vi::SimpleVarInfo) = vi.logp setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) @@ -69,17 +73,48 @@ function Base.eltype( return T end +function push!!(vi::SimpleVarInfo{Nothing}, vn::VarName{sym, Tuple{}}, value, dist::Distribution) where {sym} + @set vi.θ = NamedTuple{(sym, )}((value, )) +end +function push!!(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym, Tuple{}}, value, dist::Distribution) where {sym} + @set vi.θ = merge(vi.θ, NamedTuple{(sym, )}((value, ))) +end + # Context implementations -# Only evaluation makes sense for `SimpleVarInfo`, so we only implement this. -function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo{<:NamedTuple}) +function tilde_assume!!(context, right, vn, inds, vi::SimpleVarInfo) + value, logp, vi_new = tilde_assume(context, right, vn, inds, vi) + return value, acclogp!!(vi_new, logp) +end + +function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo) left = vi[vn] - return left, Distributions.loglikelihood(dist, left) + return left, Distributions.loglikelihood(dist, left), vi +end + +function assume( + rng::Random.AbstractRNG, + sampler::SampleFromPrior, + dist::Distribution, + vn::VarName, + vi::SimpleVarInfo +) + value = init(rng, dist, sampler) + vi = push!!(vi, vn, value, dist, sampler) + vi = settrans!!(vi, false, vn) + return value, Distributions.loglikelihood(dist, value), vi end # function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) # throw(MethodError(dot_tilde_assume!!, (context, right, left, vn, inds, vi))) # end +function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) + value, logp, vi_new = dot_tilde_assume(context, right, left, vn, inds, vi) + # Mutation of `value` no longer occurs in main body, so we do it here. + left .= value + return value, acclogp!!(vi_new, logp) +end + function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, @@ -93,11 +128,11 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns] - lp = sum(zip(vns, eachcol(r))) do vn, ri - return Distributions.logpdf(dist, ri) + value = vi[vns] + lp = sum(zip(vns, eachcol(value))) do vn, val + return Distributions.logpdf(dist, val) end - return r, lp + return value, lp, vi end function dot_assume( @@ -112,13 +147,14 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns] - lp = sum(Distributions.logpdf.(dists, r)) - return r, lp + value = vi[vns] + lp = sum(Distributions.logpdf.(dists, value)) + return value, lp, vi end # HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleVarInfo) = nothing +settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi # Interaction with `VarInfo` SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) From d660433c4a8620a118d19a65caa192ff7baeebbb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Jul 2021 23:22:48 +0100 Subject: [PATCH 07/14] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/simple_varinfo.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 859ef0540..2796a015c 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -73,11 +73,15 @@ function Base.eltype( return T end -function push!!(vi::SimpleVarInfo{Nothing}, vn::VarName{sym, Tuple{}}, value, dist::Distribution) where {sym} - @set vi.θ = NamedTuple{(sym, )}((value, )) +function push!!( + vi::SimpleVarInfo{Nothing}, vn::VarName{sym,Tuple{}}, value, dist::Distribution +) where {sym} + @set vi.θ = NamedTuple{(sym,)}((value,)) end -function push!!(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym, Tuple{}}, value, dist::Distribution) where {sym} - @set vi.θ = merge(vi.θ, NamedTuple{(sym, )}((value, ))) +function push!!( + vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Tuple{}}, value, dist::Distribution +) where {sym} + @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) end # Context implementations @@ -96,7 +100,7 @@ function assume( sampler::SampleFromPrior, dist::Distribution, vn::VarName, - vi::SimpleVarInfo + vi::SimpleVarInfo, ) value = init(rng, dist, sampler) vi = push!!(vi, vn, value, dist, sampler) From 90cf754b3a51662b84ebef33a9b277218c0597e2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 08:06:33 +0100 Subject: [PATCH 08/14] add constructor for SimpleVarInfo using model --- src/simple_varinfo.jl | 44 ++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2796a015c..87d9b0516 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -24,6 +24,29 @@ SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(nothing) SimpleVarInfo() = SimpleVarInfo{Float64}() +# Interaction with `VarInfo` +SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) +function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names} + vals = map(names) do n + let md = getfield(vi.metadata, n) + x = map(enumerate(md.ranges)) do (i, r) + reconstruct(md.dists[i], md.vals[r]) + end + + # TODO: Doesn't support batches of `MultivariateDistribution`? + length(x) == 1 ? x[1] : x + end + end + + return SimpleVarInfo{T}(NamedTuple{names}(vals)) +end + +SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) +function SimpleVarInfo{T}(model::Model, args...) where {T<:Real} + _, svi = DynamicPPL.evaluate(model, SimpleVarInfo{T}(), args...) + return svi +end + getlogp(vi::SimpleVarInfo) = vi.logp setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) @@ -159,24 +182,3 @@ end # HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleVarInfo) = nothing settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi - -# Interaction with `VarInfo` -SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) -function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names} - vals = map(names) do n - let md = getfield(vi.metadata, n) - x = map(enumerate(md.ranges)) do (i, r) - reconstruct(md.dists[i], md.vals[r]) - end - - # TODO: Doesn't support batches of `MultivariateDistribution`? - length(x) == 1 ? x[1] : x - end - end - - return SimpleVarInfo{T}(NamedTuple{names}(vals)) -end - -function SimpleVarInfo(model::Model, args...) - return SimpleVarInfo(VarInfo(Random.GLOBAL_RNG, model, args...)) -end From 0ab9d8b40a39c83a2f8ed473ac45306f22775ddd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 08:08:39 +0100 Subject: [PATCH 09/14] improved leftover to_namedtuple_expr, fixing a bug when used with Zygote --- src/utils.jl | 37 +++++-------------------------------- 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index e77a4ecdd..db7faabbd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -44,39 +44,12 @@ function getargs_tilde(expr::Expr) end end -############################################ -# Julia 1.2 temporary fix - Julia PR 33303 # -############################################ function to_namedtuple_expr(syms, vals=syms) - if length(syms) == 0 - nt = :(NamedTuple()) - else - nt_type = Expr( - :curly, - :NamedTuple, - Expr(:tuple, QuoteNode.(syms)...), - Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in vals]...), - ) - nt = Expr(:call, :($(DynamicPPL.namedtuple)), nt_type, Expr(:tuple, vals...)) - end - return nt -end - -if VERSION == v"1.2" - @eval function namedtuple( - ::Type{NamedTuple{names,T}}, args::Tuple - ) where {names,T<:Tuple} - if length(args) != length(names) - throw(ArgumentError("Wrong number of arguments to named tuple constructor.")) - end - # Note T(args) might not return something of type T; e.g. - # Tuple{Type{Float64}}((Float64,)) returns a Tuple{DataType} - return $(Expr(:splatnew, :(NamedTuple{names,T}), :(T(args)))) - end -else - function namedtuple(::Type{NamedTuple{names,T}}, args::Tuple) where {names,T<:Tuple} - return NamedTuple{names,T}(args) - end + length(syms) == 0 && return :(NamedTuple()) + + names_expr = Expr(:tuple, QuoteNode.(syms)...) + vals_expr = Expr(:tuple, vals...) + return :(NamedTuple{$names_expr}($vals_expr)) end ##################################################### From 42ad5524d3a29f650e405079c1d8921d4845b536 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 08:09:22 +0100 Subject: [PATCH 10/14] bumped patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e17343312..e2c1dd29e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.12.2" +version = "0.12.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From d0a08f694752ccc9fd55cd2cb9f9f53c1ce6c96c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Aug 2021 04:13:39 +0100 Subject: [PATCH 11/14] fixed some issues and added support for usage of Dict in SimpleVarInfo --- src/simple_varinfo.jl | 67 ++++++++++++++++++++++++++++--------------- src/varinfo.jl | 20 +++++++++++++ 2 files changed, 64 insertions(+), 23 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 87d9b0516..879b40a65 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -21,32 +21,27 @@ end SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) -SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(nothing) +SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(NamedTuple()) SimpleVarInfo() = SimpleVarInfo{Float64}() -# Interaction with `VarInfo` -SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) -function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names} - vals = map(names) do n - let md = getfield(vi.metadata, n) - x = map(enumerate(md.ranges)) do (i, r) - reconstruct(md.dists[i], md.vals[r]) - end - - # TODO: Doesn't support batches of `MultivariateDistribution`? - length(x) == 1 ? x[1] : x - end - end - - return SimpleVarInfo{T}(NamedTuple{names}(vals)) -end - +# Constructor from `Model`. SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) function SimpleVarInfo{T}(model::Model, args...) where {T<:Real} _, svi = DynamicPPL.evaluate(model, SimpleVarInfo{T}(), args...) return svi end +# Constructor from `VarInfo`. +function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D} + return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) +end +function SimpleVarInfo{T}( + vi::VarInfo{<:NamedTuple{names}}, ::Type{D} +) where {T<:Real,names,D} + values = values_as(vi, D) + return SimpleVarInfo{T}(values) +end + getlogp(vi::SimpleVarInfo) = vi.logp setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) @@ -67,9 +62,28 @@ function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} return _getindex(value, inds) end +# `NamedTuple` function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} - return _getvalue(vi.θ, Val{sym}(), vn.indexing) + # If `sym` is found in `vi.θ` we assume it will be of correct + # shape to support `getindex` for `vn.indexing`. + # If `sym` is NOT found in `vi.θ`, we try `Symbol(vn)`. + # This means that we support both the following cases: + # 1. `x[1]` has been provided by the user and can be assumed to be + # of shape that allows us to call `_getvalue` on it. + # 2. `x[1]` was not provided by the user, e.g. possibly obtained by + # sampling with a `SimpleVarInfo` which then produced the key `var"x[1]"`. + return if haskey(vi.θ, sym) + _getvalue(vi.θ, Val{sym}(), vn.indexing) + else + getproperty(vi.θ, Symbol(vn)) + end end + +# `Dict` +function getval(vi::SimpleVarInfo{<:Dict}, vn::VarName) + return vi.θ[vn] +end + # `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than # just `Vector`. getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns) @@ -96,15 +110,22 @@ function Base.eltype( return T end +# `NamedTuple` function push!!( - vi::SimpleVarInfo{Nothing}, vn::VarName{sym,Tuple{}}, value, dist::Distribution + vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Tuple{}}, value, dist::Distribution, gidset::Set{Selector} ) where {sym} - @set vi.θ = NamedTuple{(sym,)}((value,)) + @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) end function push!!( - vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Tuple{}}, value, dist::Distribution + vi::SimpleVarInfo{<:NamedTuple}, vn::VarName, value, dist::Distribution, gidset::Set{Selector} ) where {sym} - @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) + @set vi.θ = merge(vi.θ, NamedTuple{(Symbol(vn),)}((value,))) +end + +# `Dict` +function push!!(vi::SimpleVarInfo{<:Dict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) + vi.θ[vn] = r + return vi end # Context implementations diff --git a/src/varinfo.jl b/src/varinfo.jl index 007e3cf3d..af5cb0622 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1491,3 +1491,23 @@ function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, return indices end + +""" + values_as(vi::TypedVarInfo, ::Type{NamedTuple}) + values_as(vi::TypedVarInfo, ::Type{Dict}) + +Return values in `vi` as the specified type, e.g. `NamedTuple` is returned if +""" +function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names} + iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) + return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) +end + +function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{Dict}) where {names} + iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) + return Dict(iter) +end + +function values_from_metadata(md::Metadata) + return (vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for vn in md.vns) +end From ff75ddc0be85532259ace44d86e8f88436195ac5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Aug 2021 04:45:03 +0100 Subject: [PATCH 12/14] added docstring and improved indexing behvaior for SimpleVarInfo --- src/simple_varinfo.jl | 96 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 90 insertions(+), 6 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 879b40a65..c01263ae1 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -6,13 +6,90 @@ using Setfield A simple wrapper of the parameters with a `logp` field for accumulation of the logdensity. -Currently only implemented for `NT <: NamedTuple`. +Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`. -## Notes +# Notes The major differences between this and `TypedVarInfo` are: 1. `SimpleVarInfo` does not require linearization. 2. `SimpleVarInfo` can use more efficient bijectors. -3. `SimpleVarInfo` only supports evaluation. +3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either + a) no indexing is used in tilde-statements, or + b) the values have been specified with the corret shapes. + +# Examples +```jldoctest; setup=:(using Distributions, Random) +julia> @model function demo() + x = Vector{Float64}(undef, 2) + for i in eachindex(x) + x[i] ~ Normal() + end + return x + end +demo (generic function with 1 method) + +julia> m = demo(); + +julia> ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()); + +julia> # Notice how the resulting `vi` has keys `(var"x[1]", var"x[2]")` + # and thus accessing these values will be type-unstable and slower. + _, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi +SimpleVarInfo{NamedTuple{(Symbol("x[1]"), Symbol("x[2]")), Tuple{Float64, Float64}}, Float64}((x[1] = 0.14447203090358265, x[2] = 0.21780448216717593), -1.8720325464921044) + +julia> # (×) SLOW!!! + DynamicPPL.getval(vi, @varname(x[1])) +0.14447203090358265 + +julia> # In addtion, we can only access varnames as they appear in the model! + DynamicPPL.getval(vi, @varname(x)) +ERROR: type NamedTuple has no field x +[...] + +julia> julia> DynamicPPL.getval(vi, @varname(x[1:2])) +ERROR: type NamedTuple has no field x[1:2] +[...] + +julia> # In contrast, if we provide the container for `x`, the `vi` now only + # has the key `x` and we access parts of it using indices. + _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi +SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64}((x = [-0.6538238172778861, 0.10742338922309654],), -2.0573897507053474) + +julia> # (✓) Vroom, vroom! FAST!!! + DynamicPPL.getval(vi, @varname(x[1])) +-0.6538238172778861 + +julia> # We can also access arbitrary varnames pointing to `x`, e.g. + DynamicPPL.getval(vi, @varname(x)) +2-element Vector{Float64}: + -0.6538238172778861 + 0.10742338922309654 + +julia> DynamicPPL.getval(vi, @varname(x[1:2])) +2-element view(::Vector{Float64}, 1:2) with eltype Float64: + -0.6538238172778861 + 0.10742338922309654 + +julia> # The better way to handle sampling of variables involving indexing + # if one does not know the varnames, is to use a `Dict` as the container instead. + # Notice that here the keys are the same as for the `SimpleVarInfo()` scenario, i.e. + # how they appear in the model. + _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); vi +SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.1292246244328437, x[2] => -1.382335836121636), -3.4308773745351453) + +julia> # (✓) Sort of fast, but only possible at runtime. + DynamicPPL.getval(vi, @varname(x[1])) +1.1292246244328437 + +julia> # And as in the `SimpleVarInfo()` case, we cannot access varnames that does + # not directly appear in the model. + DynamicPPL.getval(vi, @varname(x)) +ERROR: KeyError: key x not found +[...] + +julia> julia> DynamicPPL.getval(vi, @varname(x[1:2])) +ERROR: KeyError: key x[1:2] not found +[...] +``` """ struct SimpleVarInfo{NT,T} <: AbstractVarInfo θ::NT @@ -73,7 +150,7 @@ function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} # 2. `x[1]` was not provided by the user, e.g. possibly obtained by # sampling with a `SimpleVarInfo` which then produced the key `var"x[1]"`. return if haskey(vi.θ, sym) - _getvalue(vi.θ, Val{sym}(), vn.indexing) + maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing)) else getproperty(vi.θ, Symbol(vn)) end @@ -117,9 +194,16 @@ function push!!( @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) end function push!!( - vi::SimpleVarInfo{<:NamedTuple}, vn::VarName, value, dist::Distribution, gidset::Set{Selector} + vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, dist::Distribution, gidset::Set{Selector} ) where {sym} - @set vi.θ = merge(vi.θ, NamedTuple{(Symbol(vn),)}((value,))) + # If the key is already there, we try to update in place. + return if haskey(vi.θ, sym) + current = _getvalue(vi.θ, Val{sym}(), vn.indexing) + current .= value + vi + else + @set vi.θ = merge(vi.θ, NamedTuple{(Symbol(vn),)}((value,))) + end end # `Dict` From d29dd8f54fe3133e1f7813fff895c8a2fd983f0b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Aug 2021 04:45:36 +0100 Subject: [PATCH 13/14] formatting --- src/simple_varinfo.jl | 16 +++++++++++++--- src/varinfo.jl | 5 ++++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c01263ae1..0e5d70213 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -189,12 +189,20 @@ end # `NamedTuple` function push!!( - vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Tuple{}}, value, dist::Distribution, gidset::Set{Selector} + vi::SimpleVarInfo{<:NamedTuple}, + vn::VarName{sym,Tuple{}}, + value, + dist::Distribution, + gidset::Set{Selector}, ) where {sym} @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) end function push!!( - vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, dist::Distribution, gidset::Set{Selector} + vi::SimpleVarInfo{<:NamedTuple}, + vn::VarName{sym}, + value, + dist::Distribution, + gidset::Set{Selector}, ) where {sym} # If the key is already there, we try to update in place. return if haskey(vi.θ, sym) @@ -207,7 +215,9 @@ function push!!( end # `Dict` -function push!!(vi::SimpleVarInfo{<:Dict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) +function push!!( + vi::SimpleVarInfo{<:Dict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector} +) vi.θ[vn] = r return vi end diff --git a/src/varinfo.jl b/src/varinfo.jl index 3628ee199..6b7523fbf 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1511,5 +1511,8 @@ function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{Dict}) where {names} end function values_from_metadata(md::Metadata) - return (vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for vn in md.vns) + return ( + vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for + vn in md.vns + ) end From a72594f058e2203aab66c39ac532df24eae2bbfb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Aug 2021 05:07:41 +0100 Subject: [PATCH 14/14] dont allow sampling with indexing when using SimpleVarInfo with NamedTuple unless shapes are specified --- src/simple_varinfo.jl | 93 +++++++++++++++++-------------------------- 1 file changed, 36 insertions(+), 57 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 0e5d70213..afc176a5d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -17,8 +17,11 @@ The major differences between this and `TypedVarInfo` are: b) the values have been specified with the corret shapes. # Examples -```jldoctest; setup=:(using Distributions, Random) +```jldoctest; setup=:(using Distributions) +julia> using StableRNGs + julia> @model function demo() + m ~ Normal() x = Vector{Float64}(undef, 2) for i in eachindex(x) x[i] ~ Normal() @@ -29,59 +32,46 @@ demo (generic function with 1 method) julia> m = demo(); -julia> ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()); - -julia> # Notice how the resulting `vi` has keys `(var"x[1]", var"x[2]")` - # and thus accessing these values will be type-unstable and slower. - _, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi -SimpleVarInfo{NamedTuple{(Symbol("x[1]"), Symbol("x[2]")), Tuple{Float64, Float64}}, Float64}((x[1] = 0.14447203090358265, x[2] = 0.21780448216717593), -1.8720325464921044) - -julia> # (×) SLOW!!! - DynamicPPL.getval(vi, @varname(x[1])) -0.14447203090358265 - -julia> # In addtion, we can only access varnames as they appear in the model! - DynamicPPL.getval(vi, @varname(x)) -ERROR: type NamedTuple has no field x -[...] +julia> rng = StableRNG(42); -julia> julia> DynamicPPL.getval(vi, @varname(x[1:2])) -ERROR: type NamedTuple has no field x[1:2] -[...] +julia> ### Sampling ### + ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()); -julia> # In contrast, if we provide the container for `x`, the `vi` now only - # has the key `x` and we access parts of it using indices. +julia> # In the `NamedTuple` version we need to provide the place-holder values for + # the variablse which are using "containers", e.g. `Array`. + # In this case, this means that we need to specify `x` but not `m`. _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi -SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64}((x = [-0.6538238172778861, 0.10742338922309654],), -2.0573897507053474) +SimpleVarInfo{NamedTuple{(:x, :m), Tuple{Vector{Float64}, Float64}}, Float64}((x = [1.6642061055583879, 1.796319600944139], m = -0.16796295277202952), -5.769094411622931) julia> # (✓) Vroom, vroom! FAST!!! DynamicPPL.getval(vi, @varname(x[1])) --0.6538238172778861 +1.6642061055583879 julia> # We can also access arbitrary varnames pointing to `x`, e.g. DynamicPPL.getval(vi, @varname(x)) 2-element Vector{Float64}: - -0.6538238172778861 - 0.10742338922309654 + 1.6642061055583879 + 1.796319600944139 julia> DynamicPPL.getval(vi, @varname(x[1:2])) 2-element view(::Vector{Float64}, 1:2) with eltype Float64: - -0.6538238172778861 - 0.10742338922309654 + 1.6642061055583879 + 1.796319600944139 + +julia> # (×) If we don't provide the container... + _, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi +ERROR: type NamedTuple has no field x +[...] -julia> # The better way to handle sampling of variables involving indexing - # if one does not know the varnames, is to use a `Dict` as the container instead. - # Notice that here the keys are the same as for the `SimpleVarInfo()` scenario, i.e. - # how they appear in the model. +julia> # If one does not know the varnames, we can use a `Dict` instead. _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); vi -SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.1292246244328437, x[2] => -1.382335836121636), -3.4308773745351453) +SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.192696983568277, x[2] => 0.4914514300738121, m => 0.25572200616753643), -3.6215377732004237) julia> # (✓) Sort of fast, but only possible at runtime. DynamicPPL.getval(vi, @varname(x[1])) -1.1292246244328437 +1.192696983568277 -julia> # And as in the `SimpleVarInfo()` case, we cannot access varnames that does - # not directly appear in the model. +julia> # In addtion, we can only access varnames as they appear in the model! DynamicPPL.getval(vi, @varname(x)) ERROR: KeyError: key x not found [...] @@ -136,24 +126,15 @@ end function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} # Use `getproperty` instead of `getfield` value = getproperty(nt, sym) + # Note that this will return a `view`, even if the resulting value is 0-dim. + # This makes it possible to call `setindex!` on the result later to update + # in place even in the case where are retrieving a single element, e.g. `x[1]`. return _getindex(value, inds) end # `NamedTuple` -function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} - # If `sym` is found in `vi.θ` we assume it will be of correct - # shape to support `getindex` for `vn.indexing`. - # If `sym` is NOT found in `vi.θ`, we try `Symbol(vn)`. - # This means that we support both the following cases: - # 1. `x[1]` has been provided by the user and can be assumed to be - # of shape that allows us to call `_getvalue` on it. - # 2. `x[1]` was not provided by the user, e.g. possibly obtained by - # sampling with a `SimpleVarInfo` which then produced the key `var"x[1]"`. - return if haskey(vi.θ, sym) - maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing)) - else - getproperty(vi.θ, Symbol(vn)) - end +function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}) where {sym} + return maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing)) end # `Dict` @@ -204,14 +185,12 @@ function push!!( dist::Distribution, gidset::Set{Selector}, ) where {sym} - # If the key is already there, we try to update in place. - return if haskey(vi.θ, sym) - current = _getvalue(vi.θ, Val{sym}(), vn.indexing) - current .= value - vi - else - @set vi.θ = merge(vi.θ, NamedTuple{(Symbol(vn),)}((value,))) - end + # We update in place. + # We need a view into the array, hence we call `_getvalue` directly + # rather than `getval`. + current = _getvalue(vi.θ, Val{sym}(), vn.indexing) + current .= value + return vi end # `Dict`