diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..d7e87f50d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -101,7 +101,8 @@ export AbstractVarInfo, pointwise_loglikelihoods, # Convenience macros @addlogprob!, - @submodel + @submodel, + @observe # Reexport using Distributions: loglikelihood diff --git a/src/compiler.jl b/src/compiler.jl index 91fe78e2b..2848f8092 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -15,14 +15,14 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases: When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. """ -function isassumption(expr::Union{Symbol,Expr}) +function isassumption(expr::Union{Symbol,Expr}; check_inargs=true) vn = gensym(:vn) return quote let $vn = $(varname(expr)) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || + if ($check_inargs && !$(DynamicPPL.inargnames)($vn, __model__)) || $(DynamicPPL.inmissings)($vn, __model__) true else @@ -243,10 +243,12 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. """ -generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn) +function generate_mainbody(mod, expr, warn; kwargs...) + return generate_mainbody!(mod, Symbol[], expr, warn; kwargs...) +end -generate_mainbody!(mod, found, x, warn) = x -function generate_mainbody!(mod, found, sym::Symbol, warn) +generate_mainbody!(mod, found, x, warn; kwargs...) = x +function generate_mainbody!(mod, found, sym::Symbol, warn; kwargs...) if sym in DEPRECATED_INTERNALNAMES newsym = Symbol(:_, sym, :__) Base.depwarn( @@ -263,13 +265,15 @@ function generate_mainbody!(mod, found, sym::Symbol, warn) return sym end -function generate_mainbody!(mod, found, expr::Expr, warn) +function generate_mainbody!(mod, found, expr::Expr, warn; tilde_kwargs...) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) + return generate_mainbody!( + mod, found, macroexpand(mod, expr; recursive=true), warn; tilde_kwargs... + ) end # Modify dotted tilde operators. @@ -278,8 +282,9 @@ function generate_mainbody!(mod, found, expr::Expr, warn) L, R = args_dottilde return Base.remove_linenums!( generate_dot_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found, L, warn; tilde_kwargs...), + generate_mainbody!(mod, found, R, warn; tilde_kwargs...); + tilde_kwargs..., ), ) end @@ -290,13 +295,30 @@ function generate_mainbody!(mod, found, expr::Expr, warn) L, R = args_tilde return Base.remove_linenums!( generate_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found, L, warn; tilde_kwargs...), + generate_mainbody!(mod, found, R, warn; tilde_kwargs...); + tilde_kwargs..., ), ) end - return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) + return Expr( + expr.head, + map(x -> generate_mainbody!(mod, found, x, warn; tilde_kwargs...), expr.args)..., + ) +end + +""" + generate_tilde(left, right) + +Generate an `observe` expression for literals, e.g. `1.0` and `[1.0, ]`. +""" +function generate_tilde_literal(left, right) + return quote + $(DynamicPPL.tilde_observe!)( + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + ) + end end """ @@ -305,15 +327,9 @@ end Generate an `observe` expression for data variables and `assume` expression for parameter variables. """ -function generate_tilde(left, right) +function generate_tilde(left, right; check_inargs=true) # If the LHS is a literal, it is always an observation - if isliteral(left) - return quote - $(DynamicPPL.tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ - ) - end - end + isliteral(left) && return generate_tilde_literal(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation @@ -321,7 +337,7 @@ function generate_tilde(left, right) return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) + $isassumption = $(DynamicPPL.isassumption(left; check_inargs=check_inargs)) if $isassumption $left = $(DynamicPPL.tilde_assume!)( __context__, @@ -349,15 +365,9 @@ end Generate the expression that replaces `left .~ right` in the model body. """ -function generate_dot_tilde(left, right) +function generate_dot_tilde(left, right; check_inargs=true) # If the LHS is a literal, it is always an observation - if isliteral(left) - return quote - $(DynamicPPL.dot_tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ - ) - end - end + isliteral(left) && return generate_tilde_literal(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation @@ -365,7 +375,7 @@ function generate_dot_tilde(left, right) return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) + $isassumption = $(DynamicPPL.isassumption(left; check_inargs=check_inargs)) if $isassumption $left .= $(DynamicPPL.dot_tilde_assume!)( __context__, diff --git a/src/utils.jl b/src/utils.jl index db7faabbd..fb40e80f4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,6 +13,56 @@ macro addlogprob!(ex) end end +""" + @observe x ~ dist + @observe x .~ dist + +Force this `~` statement to be an observe-statement unless `x` is `missing`, +effectively circumventing the check as to whether `x` is the arguments. + +Only usable within the body of [@model](@ref). + +# Examples + +```jldoctest; setup = :(using Distributions) +julia> @model function demo() + x = 1.0 + @observe x ~ Normal() + + return getlogp(__varinfo__) + end; + +julia> demo()() == logpdf(Normal(), 1.0) +true + +julia> @model function demo() + x = [1.0, ] + @observe x .~ Normal() + + return getlogp(__varinfo__) + end; + +julia> demo()() == logpdf(Normal(), 1.0) +true + +julia> @model function demo(args) + x = args.x + @observe x ~ Normal() + + return getlogp(__varinfo__) + end; + +julia> demo((x = 1.0, ))() == logpdf(Normal(), 1.0) +true + +julia> VarInfo(demo((x = missing, )))[@varname(x)] !== missing +true +``` +""" +macro observe(ex) + return esc(generate_mainbody(__module__, ex, false; check_inargs=false)) +end + """ getargs_dottilde(x)