From 90060551aa574728cb49f135b97f93a5503f667f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 16 May 2021 23:35:19 +0200 Subject: [PATCH 01/16] initial work on adding logdensity definition to Model --- src/compiler.jl | 191 +++++++++++++++++++++++++++++++-- src/context_implementations.jl | 40 ++++--- src/model.jl | 11 +- 3 files changed, 216 insertions(+), 26 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 39203f5ee..099d5f647 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -70,13 +70,19 @@ end function model(mod, linenumbernode, expr, warn) modelinfo = build_model_info(expr) + modelinfo_logπ = deepcopy(modelinfo) # Generate main body modelinfo[:body] = generate_mainbody( mod, modelinfo[:modeldef][:body], warn ) - return build_output(modelinfo, linenumbernode) + # Generate logπ + modelinfo_logπ[:body] = generate_mainbody_logdensity( + mod, modelinfo_logπ[:modeldef][:body], warn + ) + + return build_output(modelinfo, modelinfo_logπ, linenumbernode) end """ @@ -298,14 +304,7 @@ hasmissing(T::Type{<:AbstractArray{TA}}) where {TA <: AbstractArray} = hasmissin hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true hasmissing(T::Type) = false -""" - build_output(modelinfo, linenumbernode) - -Builds the output expression. -""" -function build_output(modelinfo, linenumbernode) - ## Build the anonymous evaluator from the user-provided model definition. - +function build_evaluator(modelinfo) # Remove the name. evaluatordef = deepcopy(modelinfo[:modeldef]) delete!(evaluatordef, :name) @@ -328,6 +327,53 @@ function build_output(modelinfo, linenumbernode) # Replace the user-provided function body with the version created by DynamicPPL. evaluatordef[:body] = modelinfo[:body] + return evaluatordef +end + +function build_logπ(modelinfo) + # Remove the name. + def = deepcopy(modelinfo[:modeldef]) + def[:name] = :logπ + + # Add the internal arguments to the user-specified arguments (positional + keywords). + @gensym T + def[:args] = vcat( + [ + :(__model__::$(DynamicPPL.Model)), + :(__sampler__::$(DynamicPPL.AbstractSampler)), + :(__context__::$(DynamicPPL.AbstractContext)), + :(__variables__), + T + ], + modelinfo[:allargs_exprs], + ) + + # Delete the keyword arguments. + def[:kwargs] = [] + + # Replace the user-provided function body with the version created by DynamicPPL. + + def[:body] = quote + __lp__ = zero($T) + $(modelinfo[:body]) + return __lp__ + end + + return def +end + +""" + build_output(modelinfo, linenumbernode) + +Builds the output expression. +""" +function build_output(modelinfo, modelinfo_logπ, linenumbernode) + ## Build logπ. + logπdef = build_logπ(modelinfo_logπ) + + ## Build the anonymous evaluator from the user-provided model definition. + evaluatordef = build_evaluator(modelinfo) + ## Build the model function. # Extract the named tuple expression of all arguments and the default values. @@ -337,18 +383,20 @@ function build_output(modelinfo, linenumbernode) # Update the function body of the user-specified model. # We use a name for the anonymous evaluator that does not conflict with other variables. modeldef = modelinfo[:modeldef] - @gensym evaluator + @gensym evaluator logπ # We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure # that no new `LineNumberNode`s are added apart from the reference `linenumbernode` # to the call site modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) + $logπ = $(MacroTools.combinedef(logπdef)) return $(DynamicPPL.Model)( $(QuoteNode(modeldef[:name])), $evaluator, $allargs_namedtuple, $defaults_namedtuple, + $logπ ) end @@ -415,3 +463,126 @@ end floatof(::Type{T}) where {T <: Real} = typeof(one(T)/one(T)) floatof(::Type) = Real # fallback if type inference failed + + +##################### +### `logdensity` ### +##################### +generate_mainbody_logdensity(mod, expr, warn) = generate_mainbody_logdensity!(mod, Symbol[], expr, warn) + +generate_mainbody_logdensity!(mod, found, x, warn) = x +function generate_mainbody_logdensity!(mod, found, sym::Symbol, warn) + if sym in DEPRECATED_INTERNALNAMES + newsym = Symbol(:_, sym, :__) + Base.depwarn( + "internal variable `$sym` is deprecated, use `$newsym` instead.", + :generate_mainbody_logdensity!, + ) + return generate_mainbody_logdensity!(mod, found, newsym, warn) + end + + if warn && sym in INTERNALNAMES && sym ∉ found + @warn "you are using the internal variable `$sym`" + push!(found, sym) + end + + return sym +end +function generate_mainbody_logdensity!(mod, found, expr::Expr, warn) + # Do not touch interpolated expressions + expr.head === :$ && return expr.args[1] + + # If it's a macro, we expand it + if Meta.isexpr(expr, :macrocall) + return generate_mainbody_logdensity!(mod, found, macroexpand(mod, expr; recursive=true), warn) + end + + # If it's a return, we instead return `__lp__`. + if Meta.isexpr(expr, :return) + returnbody = Expr(:block, map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)...) + return :($(returnbody); return __lp__) + end + + # Modify dotted tilde operators. + args_dottilde = getargs_dottilde(expr) + if args_dottilde !== nothing + L, R = args_dottilde + left = generate_mainbody_logdensity!(mod, found, L, warn) + return generate_dot_tilde_logdensity( + left, + generate_mainbody_logdensity!(mod, found, R, warn), + ) |> Base.remove_linenums! + end + + # Modify tilde operators. + args_tilde = getargs_tilde(expr) + if args_tilde !== nothing + L, R = args_tilde + left = generate_mainbody_logdensity!(mod, found, L, warn) + return generate_tilde_logdensity( + left, + generate_mainbody_logdensity!(mod, found, R, warn), + ) |> Base.remove_linenums! + end + + return Expr(expr.head, map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)...) +end + +function generate_tilde_logdensity(left, right) + @gensym tmpright + top = [:($tmpright = $right), + :($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}} + || throw(ArgumentError($DISTMSG)))] + + if left isa Symbol || left isa Expr + @gensym vn inds isassumption + push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left)))) + + # If it's not present in args of the model, we need to extract it from `__variables__`. + return quote + $(top...) + $isassumption = $(DynamicPPL.isassumption(left)) + if $isassumption + $(vsym(left)) = __variables__.$(vsym(left)) + end + __lp__ += $(DynamicPPL.tilde_observe)( + __context__, __sampler__, $tmpright, $left, $vn, $inds, nothing + ) + end + end + + # If the LHS is a literal, it is always an observation + return quote + $(top...) + __lp__ += $(DynamicPPL.tilde_observe)(__context__, __sampler__, $tmpright, $left, nothing) + end +end + +function generate_dot_tilde_logdensity(left, right) + @gensym tmpright + top = [:($tmpright = $right), + :($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}} + || throw(ArgumentError($DISTMSG)))] + + if left isa Symbol || left isa Expr + @gensym vn inds isassumption + push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left)))) + + return quote + $(top...) + $isassumption = $(DynamicPPL.isassumption(left)) || $left === missing + if $isassumption + $(vsym(left)) = __variables__.$(vsym(left)) + end + __lp__ += $(DynamicPPL.dot_tilde_observe)( + __context__, __sampler__, $tmpright, $left, $vn, $inds, nothing + ) + end + end + + # If the LHS is a literal, it is always an observation + return quote + $(top...) + $(DynamicPPL.dot_tilde_observe)(__context__, __sampler__, $tmpright, $left, nothing) + end +end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6b3542acd..4779674c5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -87,8 +87,12 @@ and indices; if needed, these can be accessed through this function, though. """ function tilde_observe(ctx, sampler, right, left, vname, vinds, vi) logp = tilde(ctx, sampler, right, left, vi) - acclogp!(vi, logp) - return left + if vi === nothing + return logp + else + acclogp!(vi, logp) + return left + end end """ @@ -101,8 +105,12 @@ Falls back to `tilde(ctx, sampler, right, left, vi)`. """ function tilde_observe(ctx, sampler, right, left, vi) logp = tilde(ctx, sampler, right, left, vi) - acclogp!(vi, logp) - return left + if vi === nothing + return logp + else + acclogp!(vi, logp) + return left + end end @@ -148,7 +156,7 @@ function observe( value, vi, ) - increment_num_produce!(vi) + vi === nothing || increment_num_produce!(vi) return Distributions.loglikelihood(dist, value) end @@ -406,8 +414,12 @@ name and indices; if needed, these can be accessed through this function, though """ function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi) logp = dot_tilde(ctx, sampler, right, left, vi) - acclogp!(vi, logp) - return left + if vi === nothing + return logp + else + acclogp!(vi, logp) + return left + end end """ @@ -420,8 +432,12 @@ Falls back to `dot_tilde(ctx, sampler, right, left, vi)`. """ function dot_tilde_observe(ctx, sampler, right, left, vi) logp = dot_tilde(ctx, sampler, right, left, vi) - acclogp!(vi, logp) - return left + if vi === nothing + return logp + else + acclogp!(vi, logp) + return left + end end function _dot_tilde(sampler, right, left::AbstractArray, vi) @@ -443,7 +459,7 @@ function dot_observe( value::AbstractMatrix, vi, ) - increment_num_produce!(vi) + vi == nothing || increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) @@ -454,7 +470,7 @@ function dot_observe( value::AbstractArray, vi, ) - increment_num_produce!(vi) + vi == nothing || increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) @@ -465,7 +481,7 @@ function dot_observe( value::AbstractArray, vi, ) - increment_num_produce!(vi) + vi == nothing || increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return sum(zip(dists, value)) do (d, v) diff --git a/src/model.jl b/src/model.jl index b0b78f71f..544c29786 100644 --- a/src/model.jl +++ b/src/model.jl @@ -32,11 +32,12 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractProbabilisticProgram +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults, Π} <: AbstractProbabilisticProgram name::Symbol f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} + logπ::Π """ Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple) @@ -49,8 +50,9 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractProbab f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults}(name, f, args, defaults) + logπ::Π = identity + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Π} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Π}(name, f, args, defaults, logπ) end end @@ -68,9 +70,10 @@ model with different arguments. f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple = NamedTuple(), + logπ = identity ) where {F,argnames,Targs} missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) - return :(Model{$missings}(name, f, args, defaults)) + return :(Model{$missings}(name, f, args, defaults, logπ)) end """ From 6414e427a1fddf0eb16d19d8e49a2ee2d244bef5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 03:49:43 +0200 Subject: [PATCH 02/16] added EvaluationContext to avoid custom compilation --- src/context_implementations.jl | 20 +++++++++++++++++++- src/contexts.jl | 15 +++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 4779674c5..38c1fd2ad 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -54,6 +54,12 @@ function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) return value end +function tilde_assume(rng, ctx::EvaluationContext, sampler, right, vn, inds, vi) + value = _getindex(getfield(ctx.θ, getsym(vn)), inds) + tilde_observe(ctx, sampler, right, value, inds, vi) + return value +end + function _tilde(rng, sampler, right, vn::VarName, vi) return assume(rng, sampler, right, vn, vi) @@ -95,6 +101,12 @@ function tilde_observe(ctx, sampler, right, left, vname, vinds, vi) end end +function tilde_observe(ctx::EvaluationContext, sampler, right, left, vname, vinds, vi) + logp = tilde(ctx.ctx, sampler, right, left, vi) + acclogp!(ctx, logp) + return left +end + """ tilde_observe(ctx, sampler, right, left, vi) @@ -113,6 +125,12 @@ function tilde_observe(ctx, sampler, right, left, vi) end end +function tilde_observe(ctx::EvaluationContext, sampler, right, left, vname, vi) + logp = tilde(ctx.ctx, sampler, right, left, vi) + acclogp!(ctx, logp) + return left +end + _tilde(sampler, right, left, vi) = observe(sampler, right, left, vi) @@ -156,7 +174,7 @@ function observe( value, vi, ) - vi === nothing || increment_num_produce!(vi) + (vi isa VarInfo) && increment_num_produce!(vi) return Distributions.loglikelihood(dist, value) end diff --git a/src/contexts.jl b/src/contexts.jl index 2de05a034..8f2d6cd29 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -52,3 +52,18 @@ end function MiniBatchContext(ctx = DefaultContext(); batch_size, npoints) return MiniBatchContext(ctx, npoints/batch_size) end + +struct EvaluationContext{NT, T, Ctx} <: AbstractContext + θ::NT + logp::Base.RefValue{T} + ctx::Ctx +end +EvaluationContext{T}(θ, ctx = DefaultContext()) where {T<:Real} = EvaluationContext{typeof(θ), T, typeof(ctx)}(θ, Ref(zero(T)), ctx) +EvaluationContext(θ, ctx = DefaultContext()) = EvaluationContext{Float64}(θ, ctx) + +function acclogp!(ctx::EvaluationContext, logp) + ctx.logp[] += logp + return ctx +end + +getlogp(ctx::EvaluationContext) = ctx.logp[] From 95879e2e51d63b5e8edca80a897972f7d9c2bfa8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 05:00:25 +0200 Subject: [PATCH 03/16] introduce SimpleVarInfo instead of accumulating logp in EvalationContext --- src/DynamicPPL.jl | 1 + src/context_implementations.jl | 32 +++++++++++++++++--------------- src/contexts.jl | 14 ++------------ src/simple_varinfo.jl | 21 +++++++++++++++++++++ 4 files changed, 41 insertions(+), 27 deletions(-) create mode 100644 src/simple_varinfo.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 62189fa4c..dbf2b92fc 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -118,6 +118,7 @@ include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") include("varinfo.jl") +include("simple_varinfo.jl") include("threadsafe.jl") include("context_implementations.jl") include("compiler.jl") diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 38c1fd2ad..1ae916bd8 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -54,9 +54,19 @@ function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) return value end -function tilde_assume(rng, ctx::EvaluationContext, sampler, right, vn, inds, vi) - value = _getindex(getfield(ctx.θ, getsym(vn)), inds) - tilde_observe(ctx, sampler, right, value, inds, vi) +function tilde_assume(rng, ctx::EvaluationContext, sampler, right, vn, inds, vi::SimpleVarInfo) + value = _getindex(getfield(vi.θ, getsym(vn)), inds) + + # Contexts which have different behavior between `assume` and `observe` we need + # to replace with `DefaultContext` here, otherwise the observation-only + # behavior will be applied to `assume`. + # FIXME: The below doesn't necessarily work for nested contexts, e.g. if `ctx.ctx.ctx isa PriorContext`. + # This is a broader issue though, which should probably be fixed by introducing a `WrapperContext`. + if ctx.ctx isa Union{PriorContext, LikelihoodContext} + tilde_observe(DefaultContext(), sampler, right, value, vn, inds, vi) + else + tilde_observe(ctx, sampler, right, value, vn, inds, vi) + end return value end @@ -81,6 +91,10 @@ end function tilde(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi) end +function tilde(ctx::EvaluationContext, sampler, right, left, vi) + return tilde(ctx.ctx, sampler, right, left, vi) +end + """ tilde_observe(ctx, sampler, right, left, vname, vinds, vi) @@ -101,12 +115,6 @@ function tilde_observe(ctx, sampler, right, left, vname, vinds, vi) end end -function tilde_observe(ctx::EvaluationContext, sampler, right, left, vname, vinds, vi) - logp = tilde(ctx.ctx, sampler, right, left, vi) - acclogp!(ctx, logp) - return left -end - """ tilde_observe(ctx, sampler, right, left, vi) @@ -125,12 +133,6 @@ function tilde_observe(ctx, sampler, right, left, vi) end end -function tilde_observe(ctx::EvaluationContext, sampler, right, left, vname, vi) - logp = tilde(ctx.ctx, sampler, right, left, vi) - acclogp!(ctx, logp) - return left -end - _tilde(sampler, right, left, vi) = observe(sampler, right, left, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 8f2d6cd29..26427a7af 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -53,17 +53,7 @@ function MiniBatchContext(ctx = DefaultContext(); batch_size, npoints) return MiniBatchContext(ctx, npoints/batch_size) end -struct EvaluationContext{NT, T, Ctx} <: AbstractContext - θ::NT - logp::Base.RefValue{T} +struct EvaluationContext{Ctx} <: AbstractContext ctx::Ctx end -EvaluationContext{T}(θ, ctx = DefaultContext()) where {T<:Real} = EvaluationContext{typeof(θ), T, typeof(ctx)}(θ, Ref(zero(T)), ctx) -EvaluationContext(θ, ctx = DefaultContext()) = EvaluationContext{Float64}(θ, ctx) - -function acclogp!(ctx::EvaluationContext, logp) - ctx.logp[] += logp - return ctx -end - -getlogp(ctx::EvaluationContext) = ctx.logp[] +EvaluationContext() = EvaluationContext(DefaultContext()) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl new file mode 100644 index 000000000..ba347de1e --- /dev/null +++ b/src/simple_varinfo.jl @@ -0,0 +1,21 @@ +struct SimpleVarInfo{NT, T} <: AbstractVarInfo + θ::NT + logp::Base.RefValue{T} +end + +SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ), T}(θ, Ref(zero(T))) +SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) + +function setlogp!(vi::SimpleVarInfo, logp) + vi.logp[] = logp + return vi +end + +function acclogp!(vi::SimpleVarInfo, logp) + vi.logp[] += logp + return vi +end + +getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ +getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ +getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ From dfacb83f0ad1f698afdd9836f032f14dfd3b0e4c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 05:00:49 +0200 Subject: [PATCH 04/16] implements logjoint, logprior, and loglikelihood using EvaluationContext --- src/model.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/model.jl b/src/model.jl index 544c29786..0acc48bd1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -193,6 +193,13 @@ function logjoint(model::Model, varinfo::AbstractVarInfo) return getlogp(varinfo) end +function logjoint(model::Model, θ) + ctx = EvaluationContext(DefaultContext()) + vi = SimpleVarInfo(θ) + model(vi, SampleFromPrior(), ctx) + return getlogp(vi) +end + """ logprior(model::Model, varinfo::AbstractVarInfo) @@ -205,6 +212,13 @@ function logprior(model::Model, varinfo::AbstractVarInfo) return getlogp(varinfo) end +function logprior(model::Model, θ) + ctx = EvaluationContext(PriorContext()) + vi = SimpleVarInfo(θ) + model(vi, SampleFromPrior(), ctx) + return getlogp(vi) +end + """ loglikelihood(model::Model, varinfo::AbstractVarInfo) @@ -217,6 +231,13 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) return getlogp(varinfo) end +function Distributions.loglikelihood(model::Model, θ) + ctx = EvaluationContext(LikelihoodContext()) + vi = SimpleVarInfo(θ) + model(vi, SampleFromPrior(), ctx) + return getlogp(vi) +end + """ generated_quantities(model::Model, chain::AbstractChains) From 9c4f8806513d531fc4c73eeee0cebc3594381810 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 05:03:10 +0200 Subject: [PATCH 05/16] added Bijectors.bijector implementation for VarInfo --- src/DynamicPPL.jl | 1 + src/bijectors.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 src/bijectors.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index dbf2b92fc..07ee3bd6c 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -125,5 +125,6 @@ include("compiler.jl") include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") +include("bijectors.jl") end # module diff --git a/src/bijectors.jl b/src/bijectors.jl new file mode 100644 index 000000000..0aafa2d50 --- /dev/null +++ b/src/bijectors.jl @@ -0,0 +1,27 @@ +""" + bijector(varinfo::DynamicPPL.VarInfo) + +Returns a `NamedBijector` which can transform different variants of `varinfo`. +""" +@generated function _bijector(md::NamedTuple{names}; tuplify = false) where {names} + expr = Expr(:tuple) + for n in names + e = quote + if length(md.$n.dists) == 1 && md.$n.dists[1] isa $(Distributions.UnivariateDistribution) + $(Bijectors).bijector(md.$n.dists[1]) + elseif tuplify + $(Bijectors.Stacked)(map($(Bijectors).bijector, tuple(md.$n.dists...)), md.$n.ranges) + else + $(Bijectors.Stacked)(map($(Bijectors).bijector, md.$n.dists), md.$n.ranges) + end + end + push!(expr.args, e) + end + + return quote + bs = NamedTuple{$names}($expr) + return $(Bijectors).NamedBijector(bs) + end +end + +Bijectors.bijector(vi::TypedVarInfo; kwargs...) = _bijector(vi.metadata; kwargs...) From cc421bd3df21d529037e7cb8ec0c29b70e8d1875 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 05:11:31 +0200 Subject: [PATCH 06/16] fixed docstring for bijector --- src/bijectors.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/bijectors.jl b/src/bijectors.jl index 0aafa2d50..306d3b468 100644 --- a/src/bijectors.jl +++ b/src/bijectors.jl @@ -1,8 +1,3 @@ -""" - bijector(varinfo::DynamicPPL.VarInfo) - -Returns a `NamedBijector` which can transform different variants of `varinfo`. -""" @generated function _bijector(md::NamedTuple{names}; tuplify = false) where {names} expr = Expr(:tuple) for n in names @@ -24,4 +19,12 @@ Returns a `NamedBijector` which can transform different variants of `varinfo`. end end +""" + bijector(varinfo::DynamicPPL.VarInfo; tuplify = false) + +Returns a `NamedBijector` which can transform different variants of `varinfo`. + +If `tuplify` is true, then a type-stable bijector will be returned. +""" + Bijectors.bijector(vi::TypedVarInfo; kwargs...) = _bijector(vi.metadata; kwargs...) From 5a279f7ab3e2c555410950237082005609a362fa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 05:13:13 +0200 Subject: [PATCH 07/16] specify implementation of tilde_assume for SimpleVarInfo further --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 76dec117f..49b331fa1 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -56,7 +56,7 @@ function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) return value end -function tilde_assume(rng, ctx::EvaluationContext, sampler, right, vn, inds, vi::SimpleVarInfo) +function tilde_assume(rng, ctx::EvaluationContext, sampler, right, vn, inds, vi::SimpleVarInfo{<:NamedTuple}) value = _getindex(getfield(vi.θ, getsym(vn)), inds) # Contexts which have different behavior between `assume` and `observe` we need From 4e1bb8c58ac4abcd74444f93d4bac87989e2d305 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 06:08:03 +0200 Subject: [PATCH 08/16] =?UTF-8?q?updated=20impl=20of=20the=20generated=20l?= =?UTF-8?q?og=CF=80=20after=20merge=20with=20master?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/compiler.jl | 87 ++++++++++++++++++++++++++++--------------------- 1 file changed, 50 insertions(+), 37 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 527577ca4..4479e3656 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -561,60 +561,73 @@ function generate_mainbody_logdensity!(mod, found, expr::Expr, warn) end function generate_tilde_logdensity(left, right) - @gensym tmpright - top = [:($tmpright = $right), - :($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}} - || throw(ArgumentError($DISTMSG)))] - - if left isa Symbol || left isa Expr - @gensym vn inds isassumption - push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left)))) - - # If it's not present in args of the model, we need to extract it from `__variables__`. + # If the LHS is a literal, it is always an observation + if !(left isa Symbol || left isa Expr) return quote - $(top...) - $isassumption = $(DynamicPPL.isassumption(left)) - if $isassumption - $(vsym(left)) = __variables__.$(vsym(left)) - end __lp__ += $(DynamicPPL.tilde_observe)( - __context__, __sampler__, $tmpright, $left, $vn, $inds, nothing + __context__, + __sampler__, + $(DynamicPPL.check_tilde_rhs)($right), + $left, + __varinfo__, ) end end - # If the LHS is a literal, it is always an observation + @gensym vn inds isassumption + + # If it's not present in args of the model, we need to extract it from `__variables__`. return quote - $(top...) - __lp__ += $(DynamicPPL.tilde_observe)(__context__, __sampler__, $tmpright, $left, nothing) + $vn = $(varname(left)) + $inds = $(vinds(left)) + $isassumption = $(DynamicPPL.isassumption(left)) + if $isassumption + $(vsym(left)) = __variables__.$(vsym(left)) + end + __lp__ += $(DynamicPPL.tilde_observe)( + __context__, + __sampler__, + $(DynamicPPL.check_tilde_rhs)($right), + $left, + $vn, + $inds, + nothing + ) end end function generate_dot_tilde_logdensity(left, right) - @gensym tmpright - top = [:($tmpright = $right), - :($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}} - || throw(ArgumentError($DISTMSG)))] - - if left isa Symbol || left isa Expr - @gensym vn inds isassumption - push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left)))) - + # If the LHS is a literal, it is always an observation + if !(left isa Symbol || left isa Expr) return quote - $(top...) - $isassumption = $(DynamicPPL.isassumption(left)) || $left === missing - if $isassumption - $(vsym(left)) = __variables__.$(vsym(left)) - end __lp__ += $(DynamicPPL.dot_tilde_observe)( - __context__, __sampler__, $tmpright, $left, $vn, $inds, nothing + __context__, + __sampler__, + $(DynamicPPL.check_tilde_rhs)($right), + $left, + __varinfo__, ) end end - # If the LHS is a literal, it is always an observation + # Otherwise it is determined by the model or its value, + # if the LHS represents an observation + @gensym vn inds isassumption return quote - $(top...) - $(DynamicPPL.dot_tilde_observe)(__context__, __sampler__, $tmpright, $left, nothing) + $vn = $(varname(left)) + $inds = $(vinds(left)) + $isassumption = $(DynamicPPL.isassumption(left)) || $left === missing + if $isassumption + $(vsym(left)) = __variables__.$(vsym(left)) + end + __lp__ += $(DynamicPPL.dot_tilde_observe)( + __context__, + __sampler__, + $(DynamicPPL.check_tilde_rhs)($right), + $left, + $vn, + $inds, + nothing + ) end end From 6fbcdc5d94c44461541aa091a373de79b482478a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 06:08:28 +0200 Subject: [PATCH 09/16] fixed impl of dot_tilde_assume for EvaluationContext --- src/context_implementations.jl | 35 +++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 49b331fa1..79624c30a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -64,8 +64,11 @@ function tilde_assume(rng, ctx::EvaluationContext, sampler, right, vn, inds, vi: # behavior will be applied to `assume`. # FIXME: The below doesn't necessarily work for nested contexts, e.g. if `ctx.ctx.ctx isa PriorContext`. # This is a broader issue though, which should probably be fixed by introducing a `WrapperContext`. - if ctx.ctx isa Union{PriorContext, LikelihoodContext} - tilde_observe(DefaultContext(), sampler, right, value, vn, inds, vi) + if ctx.ctx isa PriorContext + tilde_observe(LikelihoodContext(), sampler, right, value, vn, inds, vi) + elseif ctx.ctx isa LikelihoodContext + # Need to make it so that this isn't computed. + tilde_observe(PriorContext(), sampler, right, value, vn, inds, vi) else tilde_observe(ctx, sampler, right, value, vn, inds, vi) end @@ -223,6 +226,25 @@ function dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) return value end +function dot_tilde_assume(rng, ctx::EvaluationContext, sampler, right, left, vn, inds, vi::SimpleVarInfo{<:NamedTuple}) + value = _getindex(getfield(vi.θ, getsym(vn)), inds) + + # Contexts which have different behavior between `assume` and `observe` we need + # to replace with `DefaultContext` here, otherwise the observation-only + # behavior will be applied to `assume`. + # FIXME: The below doesn't necessarily work for nested contexts, e.g. if `ctx.ctx.ctx isa PriorContext`. + # This is a broader issue though, which should probably be fixed by introducing a `WrapperContext`. + if ctx.ctx isa PriorContext + dot_tilde_observe(LikelihoodContext(), sampler, right, value, vn, inds, vi) + elseif ctx.ctx isa LikelihoodContext + # Need to make it so that this isn't computed. + dot_tilde_observe(PriorContext(), sampler, right, value, vn, inds, vi) + else + dot_tilde_observe(ctx.ctx, sampler, right, value, vn, inds, vi) + end + return value +end + function get_vns_and_dist(dist::NamedDist, var, vn::VarName) return get_vns_and_dist(dist.dist, var, dist.name) end @@ -388,6 +410,9 @@ end function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, vi) end +function dot_tilde(ctx::EvaluationContext, sampler, right, left, vi) + return dot_tilde(ctx.ctx, sampler, right, left, vi) +end """ dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi) @@ -445,7 +470,7 @@ function dot_observe( value::AbstractMatrix, vi, ) - vi == nothing || increment_num_produce!(vi) + vi isa VarInfo && increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) @@ -456,7 +481,7 @@ function dot_observe( value::AbstractArray, vi, ) - vi == nothing || increment_num_produce!(vi) + vi isa VarInfo && increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) @@ -467,7 +492,7 @@ function dot_observe( value::AbstractArray, vi, ) - vi == nothing || increment_num_produce!(vi) + vi isa VarInfo && increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) From 74d620a7c33444d9b95ae8b86c9cca16db5ae576 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 06:48:30 +0200 Subject: [PATCH 10/16] =?UTF-8?q?renamed=20log=CF=80=20to=20logjoint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/compiler.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 4479e3656..3c4d97757 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -84,17 +84,17 @@ end function model(mod, linenumbernode, expr, warn) modelinfo = build_model_info(expr) - modelinfo_logπ = deepcopy(modelinfo) + modelinfo_logjoint = deepcopy(modelinfo) # Generate main body modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) - # Generate logπ - modelinfo_logπ[:body] = generate_mainbody_logdensity( - mod, modelinfo_logπ[:modeldef][:body], warn + # Generate logjoint + modelinfo_logjoint[:body] = generate_mainbody_logdensity( + mod, modelinfo_logjoint[:modeldef][:body], warn ) - return build_output(modelinfo, modelinfo_logπ, linenumbernode) + return build_output(modelinfo, modelinfo_logjoint, linenumbernode) end """ @@ -371,10 +371,10 @@ function build_evaluator(modelinfo) return evaluatordef end -function build_logπ(modelinfo) +function build_logjoint(modelinfo) # Remove the name. def = deepcopy(modelinfo[:modeldef]) - def[:name] = :logπ + def[:name] = :logjoint # Add the internal arguments to the user-specified arguments (positional + keywords). @gensym T @@ -408,9 +408,9 @@ end Builds the output expression. """ -function build_output(modelinfo, modelinfo_logπ, linenumbernode) - ## Build logπ. - logπdef = build_logπ(modelinfo_logπ) +function build_output(modelinfo, modelinfo_logjoint, linenumbernode) + ## Build logjoint. + logjointdef = build_logjoint(modelinfo_logjoint) ## Build the anonymous evaluator from the user-provided model definition. evaluatordef = build_evaluator(modelinfo) @@ -424,20 +424,20 @@ function build_output(modelinfo, modelinfo_logπ, linenumbernode) # Update the function body of the user-specified model. # We use a name for the anonymous evaluator that does not conflict with other variables. modeldef = modelinfo[:modeldef] - @gensym evaluator logπ + @gensym evaluator logjoint # We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure # that no new `LineNumberNode`s are added apart from the reference `linenumbernode` # to the call site modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) - $logπ = $(MacroTools.combinedef(logπdef)) + $logjoint = $(MacroTools.combinedef(logjointdef)) return $(DynamicPPL.Model)( $(QuoteNode(modeldef[:name])), $evaluator, $allargs_namedtuple, $defaults_namedtuple, - $logπ + $logjoint ) end From ee62924a7c7f0157cd70f581b202b724ae0a9b8f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 06:48:39 +0200 Subject: [PATCH 11/16] =?UTF-8?q?renamed=20log=CF=80=20to=20logjoint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/model.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/model.jl b/src/model.jl index 0b145a9e7..a55f2d108 100644 --- a/src/model.jl +++ b/src/model.jl @@ -32,12 +32,12 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults, Π} <: AbstractProbabilisticProgram +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults, Logjoint} <: AbstractProbabilisticProgram name::Symbol f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} - logπ::Π + logjoint::Logjoint """ Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple) @@ -50,9 +50,9 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults, Π} <: AbstractPr f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, - logπ::Π = identity - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Π} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Π}(name, f, args, defaults, logπ) + logjoint::Logjoint = identity + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Logjoint} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Logjoint}(name, f, args, defaults, logjoint) end end @@ -70,10 +70,10 @@ model with different arguments. f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple = NamedTuple(), - logπ = identity + logjoint = identity ) where {F,argnames,Targs} missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) - return :(Model{$missings}(name, f, args, defaults, logπ)) + return :(Model{$missings}(name, f, args, defaults, logjoint)) end """ From 13045ad051c642a91f5b6ba9956fa91d1be5889d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 05:49:44 +0100 Subject: [PATCH 12/16] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/bijectors.jl | 9 +++++--- src/compiler.jl | 40 +++++++++++++++++++--------------- src/context_implementations.jl | 16 +++++++++++--- src/simple_varinfo.jl | 4 ++-- 4 files changed, 44 insertions(+), 25 deletions(-) diff --git a/src/bijectors.jl b/src/bijectors.jl index 306d3b468..08c242038 100644 --- a/src/bijectors.jl +++ b/src/bijectors.jl @@ -1,11 +1,14 @@ -@generated function _bijector(md::NamedTuple{names}; tuplify = false) where {names} +@generated function _bijector(md::NamedTuple{names}; tuplify=false) where {names} expr = Expr(:tuple) for n in names e = quote - if length(md.$n.dists) == 1 && md.$n.dists[1] isa $(Distributions.UnivariateDistribution) + if length(md.$n.dists) == 1 && + md.$n.dists[1] isa $(Distributions.UnivariateDistribution) $(Bijectors).bijector(md.$n.dists[1]) elseif tuplify - $(Bijectors.Stacked)(map($(Bijectors).bijector, tuple(md.$n.dists...)), md.$n.ranges) + $(Bijectors.Stacked)( + map($(Bijectors).bijector, tuple(md.$n.dists...)), md.$n.ranges + ) else $(Bijectors.Stacked)(map($(Bijectors).bijector, md.$n.dists), md.$n.ranges) end diff --git a/src/compiler.jl b/src/compiler.jl index 3c4d97757..58ea7ab7f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -384,7 +384,7 @@ function build_logjoint(modelinfo) :(__sampler__::$(DynamicPPL.AbstractSampler)), :(__context__::$(DynamicPPL.AbstractContext)), :(__variables__), - T + T, ], modelinfo[:allargs_exprs], ) @@ -393,7 +393,6 @@ function build_logjoint(modelinfo) def[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - def[:body] = quote __lp__ = zero($T) $(modelinfo[:body]) @@ -496,11 +495,12 @@ end floatof(::Type{T}) where {T<:Real} = typeof(one(T) / one(T)) floatof(::Type) = Real # fallback if type inference failed - ##################### ### `logdensity` ### ##################### -generate_mainbody_logdensity(mod, expr, warn) = generate_mainbody_logdensity!(mod, Symbol[], expr, warn) +function generate_mainbody_logdensity(mod, expr, warn) + return generate_mainbody_logdensity!(mod, Symbol[], expr, warn) +end generate_mainbody_logdensity!(mod, found, x, warn) = x function generate_mainbody_logdensity!(mod, found, sym::Symbol, warn) @@ -526,12 +526,17 @@ function generate_mainbody_logdensity!(mod, found, expr::Expr, warn) # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - return generate_mainbody_logdensity!(mod, found, macroexpand(mod, expr; recursive=true), warn) + return generate_mainbody_logdensity!( + mod, found, macroexpand(mod, expr; recursive=true), warn + ) end # If it's a return, we instead return `__lp__`. if Meta.isexpr(expr, :return) - returnbody = Expr(:block, map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)...) + returnbody = Expr( + :block, + map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)..., + ) return :($(returnbody); return __lp__) end @@ -540,10 +545,9 @@ function generate_mainbody_logdensity!(mod, found, expr::Expr, warn) if args_dottilde !== nothing L, R = args_dottilde left = generate_mainbody_logdensity!(mod, found, L, warn) - return generate_dot_tilde_logdensity( - left, - generate_mainbody_logdensity!(mod, found, R, warn), - ) |> Base.remove_linenums! + return Base.remove_linenums!(generate_dot_tilde_logdensity( + left, generate_mainbody_logdensity!(mod, found, R, warn) + )) end # Modify tilde operators. @@ -551,13 +555,15 @@ function generate_mainbody_logdensity!(mod, found, expr::Expr, warn) if args_tilde !== nothing L, R = args_tilde left = generate_mainbody_logdensity!(mod, found, L, warn) - return generate_tilde_logdensity( - left, - generate_mainbody_logdensity!(mod, found, R, warn), - ) |> Base.remove_linenums! + return Base.remove_linenums!(generate_tilde_logdensity( + left, generate_mainbody_logdensity!(mod, found, R, warn) + )) end - return Expr(expr.head, map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)...) + return Expr( + expr.head, + map(x -> generate_mainbody_logdensity!(mod, found, x, warn), expr.args)..., + ) end function generate_tilde_logdensity(left, right) @@ -591,7 +597,7 @@ function generate_tilde_logdensity(left, right) $left, $vn, $inds, - nothing + nothing, ) end end @@ -627,7 +633,7 @@ function generate_dot_tilde_logdensity(left, right) $left, $vn, $inds, - nothing + nothing, ) end end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 79624c30a..d0c8a3e4a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -56,7 +56,9 @@ function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) return value end -function tilde_assume(rng, ctx::EvaluationContext, sampler, right, vn, inds, vi::SimpleVarInfo{<:NamedTuple}) +function tilde_assume( + rng, ctx::EvaluationContext, sampler, right, vn, inds, vi::SimpleVarInfo{<:NamedTuple} +) value = _getindex(getfield(vi.θ, getsym(vn)), inds) # Contexts which have different behavior between `assume` and `observe` we need @@ -75,7 +77,6 @@ function tilde_assume(rng, ctx::EvaluationContext, sampler, right, vn, inds, vi: return value end - function _tilde(rng, sampler, right, vn::VarName, vi) return assume(rng, sampler, right, vn, vi) end @@ -226,7 +227,16 @@ function dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) return value end -function dot_tilde_assume(rng, ctx::EvaluationContext, sampler, right, left, vn, inds, vi::SimpleVarInfo{<:NamedTuple}) +function dot_tilde_assume( + rng, + ctx::EvaluationContext, + sampler, + right, + left, + vn, + inds, + vi::SimpleVarInfo{<:NamedTuple}, +) value = _getindex(getfield(vi.θ, getsym(vn)), inds) # Contexts which have different behavior between `assume` and `observe` we need diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ba347de1e..6d9fb72fd 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -1,9 +1,9 @@ -struct SimpleVarInfo{NT, T} <: AbstractVarInfo +struct SimpleVarInfo{NT,T} <: AbstractVarInfo θ::NT logp::Base.RefValue{T} end -SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ), T}(θ, Ref(zero(T))) +SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, Ref(zero(T))) SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) function setlogp!(vi::SimpleVarInfo, logp) From 290bd38ed02dd465fbcea3f031e5823ce53e3cb1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 05:50:27 +0100 Subject: [PATCH 13/16] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 58ea7ab7f..f631a6b70 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -436,7 +436,7 @@ function build_output(modelinfo, modelinfo_logjoint, linenumbernode) $evaluator, $allargs_namedtuple, $defaults_namedtuple, - $logjoint + $logjoint, ) end From d65337d696303a0ea951bf7aa8fea642d15fb894 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 06:03:49 +0100 Subject: [PATCH 14/16] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/compiler.jl | 16 ++++++++++------ src/model.jl | 3 ++- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index f631a6b70..14ab36e5a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -545,9 +545,11 @@ function generate_mainbody_logdensity!(mod, found, expr::Expr, warn) if args_dottilde !== nothing L, R = args_dottilde left = generate_mainbody_logdensity!(mod, found, L, warn) - return Base.remove_linenums!(generate_dot_tilde_logdensity( - left, generate_mainbody_logdensity!(mod, found, R, warn) - )) + return Base.remove_linenums!( + generate_dot_tilde_logdensity( + left, generate_mainbody_logdensity!(mod, found, R, warn) + ), + ) end # Modify tilde operators. @@ -555,9 +557,11 @@ function generate_mainbody_logdensity!(mod, found, expr::Expr, warn) if args_tilde !== nothing L, R = args_tilde left = generate_mainbody_logdensity!(mod, found, L, warn) - return Base.remove_linenums!(generate_tilde_logdensity( - left, generate_mainbody_logdensity!(mod, found, R, warn) - )) + return Base.remove_linenums!( + generate_tilde_logdensity( + left, generate_mainbody_logdensity!(mod, found, R, warn) + ), + ) end return Expr( diff --git a/src/model.jl b/src/model.jl index a55f2d108..c83d1f57b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -32,7 +32,8 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults, Logjoint} <: AbstractProbabilisticProgram +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Logjoint} <: + AbstractProbabilisticProgram name::Symbol f::F args::NamedTuple{argnames,Targs} From e7fbc15972cb9ef10e68035350a1da570186da83 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 11:06:38 +0200 Subject: [PATCH 15/16] fixed type-instability in generated logjoint --- src/compiler.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index f631a6b70..24aafaf61 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -384,11 +384,14 @@ function build_logjoint(modelinfo) :(__sampler__::$(DynamicPPL.AbstractSampler)), :(__context__::$(DynamicPPL.AbstractContext)), :(__variables__), - T, ], modelinfo[:allargs_exprs], + [Expr(:kw, :(::Type{$T}), :Float64), ] ) + # Add the type-parameter. + def[:whereparams] = (def[:whereparams]..., T) + # Delete the keyword arguments. def[:kwargs] = [] From ad1e8527f17c9d34f55442a8a0caf6d5328e6cf5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 20 May 2021 10:08:16 +0100 Subject: [PATCH 16/16] Apply suggestions from code review Co-authored-by: David Widmann --- src/bijectors.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/bijectors.jl b/src/bijectors.jl index 08c242038..8a1701455 100644 --- a/src/bijectors.jl +++ b/src/bijectors.jl @@ -23,11 +23,10 @@ end """ - bijector(varinfo::DynamicPPL.VarInfo; tuplify = false) + bijector(varinfo::VarInfo; tuplify=false) Returns a `NamedBijector` which can transform different variants of `varinfo`. If `tuplify` is true, then a type-stable bijector will be returned. """ - Bijectors.bijector(vi::TypedVarInfo; kwargs...) = _bijector(vi.metadata; kwargs...)