diff --git a/Project.toml b/Project.toml
index a62bf148c..89abcbfcf 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 
 [compat]
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index 77d271923..a2ed246b4 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -34,6 +34,7 @@ export AbstractVarInfo,
     VarInfo,
     UntypedVarInfo,
     TypedVarInfo,
+    SimpleVarInfo,
     push!!,
     empty!!,
     getlogp,
@@ -135,6 +136,7 @@ include("varname.jl")
 include("distribution_wrappers.jl")
 include("contexts.jl")
 include("varinfo.jl")
+include("simple_varinfo.jl")
 include("threadsafe.jl")
 include("context_implementations.jl")
 include("compiler.jl")
diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl
new file mode 100644
index 000000000..afc176a5d
--- /dev/null
+++ b/src/simple_varinfo.jl
@@ -0,0 +1,284 @@
+using Setfield
+
+"""
+    SimpleVarInfo{NT,T} <: AbstractVarInfo
+
+A simple wrapper of the parameters with a `logp` field for
+accumulation of the logdensity.
+
+Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`.
+
+# Notes
+The major differences between this and `TypedVarInfo` are:
+1. `SimpleVarInfo` does not require linearization.
+2. `SimpleVarInfo` can use more efficient bijectors.
+3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either
+   a) no indexing is used in tilde-statements, or
+   b) the values have been specified with the correct shapes.
+
+# Examples
+```jldoctest; setup=:(using Distributions)
+julia> using StableRNGs
+
+julia> @model function demo()
+           m ~ Normal()
+           x = Vector{Float64}(undef, 2)
+           for i in eachindex(x)
+               x[i] ~ Normal()
+           end
+           return x
+       end
+demo (generic function with 1 method)
+
+julia> m = demo();
+
+julia> rng = StableRNG(42);
+
+julia> ### Sampling ###
+       ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext());
+
+julia> # In the `NamedTuple` version we need to provide the place-holder values for
+       # the variables which are using "containers", e.g. `Array`.
+       # In this case, this means that we need to specify `x` but not `m`.
+       _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi
+SimpleVarInfo{NamedTuple{(:x, :m), Tuple{Vector{Float64}, Float64}}, Float64}((x = [1.6642061055583879, 1.796319600944139], m = -0.16796295277202952), -5.769094411622931)
+
+julia> # (✓) Vroom, vroom! FAST!!!
+       DynamicPPL.getval(vi, @varname(x[1]))
+1.6642061055583879
+
+julia> # We can also access arbitrary varnames pointing to `x`, e.g.
+       DynamicPPL.getval(vi, @varname(x))
+2-element Vector{Float64}:
+ 1.6642061055583879
+ 1.796319600944139
+
+julia> DynamicPPL.getval(vi, @varname(x[1:2]))
+2-element view(::Vector{Float64}, 1:2) with eltype Float64:
+ 1.6642061055583879
+ 1.796319600944139
+
+julia> # (×) If we don't provide the container...
+       _, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi
+ERROR: type NamedTuple has no field x
+[...]
+
+julia> # If one does not know the varnames, we can use a `Dict` instead.
+       _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); vi
+SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.192696983568277, x[2] => 0.4914514300738121, m => 0.25572200616753643), -3.6215377732004237)
+
+julia> # (✓) Sort of fast, but only possible at runtime.
+       DynamicPPL.getval(vi, @varname(x[1]))
+1.192696983568277
+
+julia> # In addition, we can only access varnames as they appear in the model!
+       DynamicPPL.getval(vi, @varname(x))
+ERROR: KeyError: key x not found
+[...]
+
+julia> DynamicPPL.getval(vi, @varname(x[1:2]))
+ERROR: KeyError: key x[1:2] not found
+[...]
+```
+"""
+struct SimpleVarInfo{NT,T} <: AbstractVarInfo
+    θ::NT
+    logp::T
+end
+
+SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T))
+function SimpleVarInfo(θ)
+    # Infer the `logp` type from the first *value* so that `Dict` containers work too
+    # (`first(::Dict)` is a `Pair`); fall back to `Float64` for empty containers.
+    vals = values(θ)
+    isempty(vals) && return SimpleVarInfo{Float64}(θ)
+    return SimpleVarInfo{eltype(first(vals))}(θ)
+end
+SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(NamedTuple())
+SimpleVarInfo() = SimpleVarInfo{Float64}()
+
+# Constructor from `Model`.
+SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...)
+function SimpleVarInfo{T}(model::Model, args...) where {T<:Real}
+    _, svi = DynamicPPL.evaluate(model, SimpleVarInfo{T}(), args...)
+    return svi
+end
+
+# Constructor from `VarInfo`.
+function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D}
+    return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...)
+end
+function SimpleVarInfo{T}(
+    vi::VarInfo{<:NamedTuple{names}}, ::Type{D}
+) where {T<:Real,names,D}
+    values = values_as(vi, D)
+    return SimpleVarInfo{T}(values)
+end
+
+getlogp(vi::SimpleVarInfo) = vi.logp
+setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp)
+acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp)
+
+function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
+    vi.logp[] = logp
+    return vi
+end
+
+function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
+    vi.logp[] += logp
+    return vi
+end
+
+function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym}
+    # Use `getproperty` instead of `getfield`
+    value = getproperty(nt, sym)
+    # Note that this will return a `view`, even if the resulting value is 0-dim.
+    # This makes it possible to call `setindex!` on the result later to update
+    # in place even in the case where we are retrieving a single element, e.g. `x[1]`.
+    return _getindex(value, inds)
+end
+
+# `NamedTuple`
+function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}) where {sym}
+    return maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing))
+end
+
+# `Dict`
+function getval(vi::SimpleVarInfo{<:Dict}, vn::VarName)
+    return vi.θ[vn]
+end
+
+# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than
+# just `Vector`.
+getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns)
+# To disambiguate.
+getval(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getval(vi, vn), vns)
+
+haskey(vi::SimpleVarInfo, vn) = haskey(vi.θ, getsym(vn))
+# `Dict`-backed instances are keyed by the full `VarName` (see `getval`/`push!!`),
+# so the lookup must use `vn` itself rather than its symbol.
+haskey(vi::SimpleVarInfo{<:Dict}, vn::VarName) = haskey(vi.θ, vn)
+
+istrans(::SimpleVarInfo, vn::VarName) = false
+
+getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ
+getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ
+# TODO: Should we do better?
+getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ
+getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn)
+getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns)
+# HACK: Need to disambiguate.
+getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns)
+
+# Necessary for `matchingvalue` to work properly.
+function Base.eltype(
+    vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior}
+) where {T}
+    return T
+end
+
+# `NamedTuple`
+function push!!(
+    vi::SimpleVarInfo{<:NamedTuple},
+    vn::VarName{sym,Tuple{}},
+    value,
+    dist::Distribution,
+    gidset::Set{Selector},
+) where {sym}
+    @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,)))
+end
+function push!!(
+    vi::SimpleVarInfo{<:NamedTuple},
+    vn::VarName{sym},
+    value,
+    dist::Distribution,
+    gidset::Set{Selector},
+) where {sym}
+    # We update in place.
+    # We need a view into the array, hence we call `_getvalue` directly
+    # rather than `getval`.
+    current = _getvalue(vi.θ, Val{sym}(), vn.indexing)
+    current .= value
+    return vi
+end
+
+# `Dict`
+function push!!(
+    vi::SimpleVarInfo{<:Dict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector}
+)
+    vi.θ[vn] = r
+    return vi
+end
+
+# Context implementations
+function tilde_assume!!(context, right, vn, inds, vi::SimpleVarInfo)
+    value, logp, vi_new = tilde_assume(context, right, vn, inds, vi)
+    return value, acclogp!!(vi_new, logp)
+end
+
+function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo)
+    left = vi[vn]
+    return left, Distributions.loglikelihood(dist, left), vi
+end
+
+function assume(
+    rng::Random.AbstractRNG,
+    sampler::SampleFromPrior,
+    dist::Distribution,
+    vn::VarName,
+    vi::SimpleVarInfo,
+)
+    value = init(rng, dist, sampler)
+    vi = push!!(vi, vn, value, dist, sampler)
+    vi = settrans!!(vi, false, vn)
+    return value, Distributions.loglikelihood(dist, value), vi
+end
+
+function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo)
+    value, logp, vi_new = dot_tilde_assume(context, right, left, vn, inds, vi)
+    # Mutation of `value` no longer occurs in main body, so we do it here.
+    left .= value
+    return value, acclogp!!(vi_new, logp)
+end
+
+function dot_assume(
+    dist::MultivariateDistribution,
+    var::AbstractMatrix,
+    vns::AbstractVector{<:VarName},
+    vi::SimpleVarInfo,
+)
+    @assert length(dist) == size(var, 1)
+    # NOTE: We cannot work with `var` here because we might have a model of the form
+    #
+    #     m = Vector{Float64}(undef, n)
+    #     m .~ Normal()
+    #
+    # in which case `var` will have `undef` elements, even if `m` is present in `vi`.
+    value = vi[vns]
+    # `sum` passes each zipped pair as a single tuple, so destructure it in the signature.
+    lp = sum(zip(vns, eachcol(value))) do (vn, val)
+        return Distributions.logpdf(dist, val)
+    end
+    return value, lp, vi
+end
+
+function dot_assume(
+    dists::Union{Distribution,AbstractArray{<:Distribution}},
+    var::AbstractArray,
+    vns::AbstractArray{<:VarName},
+    vi::SimpleVarInfo{<:NamedTuple},
+)
+    # NOTE: We cannot work with `var` here because we might have a model of the form
+    #
+    #     m = Vector{Float64}(undef, n)
+    #     m .~ Normal()
+    #
+    # in which case `var` will have `undef` elements, even if `m` is present in `vi`.
+    value = vi[vns]
+    lp = sum(Distributions.logpdf.(dists, value))
+    return value, lp, vi
+end
+
+# HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals.
+increment_num_produce!(::SimpleVarInfo) = nothing
+settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi
diff --git a/src/varinfo.jl b/src/varinfo.jl
index b03970c43..6b7523fbf 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -1493,3 +1493,27 @@ function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values,
 
     return indices
 end
+
+"""
+    values_as(vi::TypedVarInfo, ::Type{NamedTuple})
+    values_as(vi::TypedVarInfo, ::Type{Dict})
+
+Return the values in `vi` as the specified type, e.g. a `NamedTuple` is returned if
+`NamedTuple` is requested and a `Dict` is returned if `Dict` is requested.
+"""
+function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names}
+    iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names)
+    return NamedTuple(map(p -> Symbol(p.first) => p.second, iter))
+end
+
+function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{Dict}) where {names}
+    iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names)
+    return Dict(iter)
+end
+
+function values_from_metadata(md::Metadata)
+    return (
+        vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for
+        vn in md.vns
+    )
+end