Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ export AbstractVarInfo,
pointwise_loglikelihoods,
# Convenience macros
@addlogprob!,
@submodel
@submodel,
@observe

# Reexport
using Distributions: loglikelihood
Expand Down
70 changes: 40 additions & 30 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
"""
function isassumption(expr::Union{Symbol,Expr})
function isassumption(expr::Union{Symbol,Expr}; check_inargs=true)
vn = gensym(:vn)

return quote
let $vn = $(varname(expr))
# This branch should compile nicely in all cases except for partial missing data
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
if !$(DynamicPPL.inargnames)($vn, __model__) ||
if ($check_inargs && !$(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
else
Expand Down Expand Up @@ -243,10 +243,12 @@ Generate the body of the main evaluation function from expression `expr` and arg
If `warn` is true, a warning is displayed if internal variables are used in the model
definition.
"""
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
function generate_mainbody(mod, expr, warn; kwargs...)
return generate_mainbody!(mod, Symbol[], expr, warn; kwargs...)
end

generate_mainbody!(mod, found, x, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, warn)
generate_mainbody!(mod, found, x, warn; kwargs...) = x
function generate_mainbody!(mod, found, sym::Symbol, warn; kwargs...)
if sym in DEPRECATED_INTERNALNAMES
newsym = Symbol(:_, sym, :__)
Base.depwarn(
Expand All @@ -263,13 +265,15 @@ function generate_mainbody!(mod, found, sym::Symbol, warn)

return sym
end
function generate_mainbody!(mod, found, expr::Expr, warn)
function generate_mainbody!(mod, found, expr::Expr, warn; tilde_kwargs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the new kwargs always tilde_kwargs? If yes, I'd suggest consistently using the latter name, also in the above methods.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is right now yes

# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
return generate_mainbody!(
mod, found, macroexpand(mod, expr; recursive=true), warn; tilde_kwargs...
)
end

# Modify dotted tilde operators.
Expand All @@ -278,8 +282,9 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
L, R = args_dottilde
return Base.remove_linenums!(
generate_dot_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found, L, warn; tilde_kwargs...),
generate_mainbody!(mod, found, R, warn; tilde_kwargs...);
tilde_kwargs...,
),
)
end
Expand All @@ -290,13 +295,30 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
L, R = args_tilde
return Base.remove_linenums!(
generate_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found, L, warn; tilde_kwargs...),
generate_mainbody!(mod, found, R, warn; tilde_kwargs...);
tilde_kwargs...,
),
)
end

return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
return Expr(
expr.head,
map(x -> generate_mainbody!(mod, found, x, warn; tilde_kwargs...), expr.args)...,
)
end

"""
generate_tilde(left, right)

Generate an `observe` expression for literals, e.g. `1.0` and `[1.0, ]`.
"""
function generate_tilde_literal(left, right)
return quote
$(DynamicPPL.tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
)
end
end

"""
Expand All @@ -305,23 +327,17 @@ end
Generate an `observe` expression for data variables and `assume` expression for parameter
variables.
"""
function generate_tilde(left, right)
function generate_tilde(left, right; check_inargs=true)
# If the LHS is a literal, it is always an observation
if isliteral(left)
return quote
$(DynamicPPL.tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
)
end
end
isliteral(left) && return generate_tilde_literal(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left; check_inargs=check_inargs))
if $isassumption
$left = $(DynamicPPL.tilde_assume!)(
__context__,
Expand Down Expand Up @@ -349,23 +365,17 @@ end

Generate the expression that replaces `left .~ right` in the model body.
"""
function generate_dot_tilde(left, right)
function generate_dot_tilde(left, right; check_inargs=true)
# If the LHS is a literal, it is always an observation
if isliteral(left)
return quote
$(DynamicPPL.dot_tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
)
end
end
isliteral(left) && return generate_tilde_literal(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left; check_inargs=check_inargs))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
Expand Down
50 changes: 50 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,56 @@ macro addlogprob!(ex)
end
end

"""
@observe x ~ dist
@observe x .~ dist

Force this `~` statement to be an observe-statement unless `x` is `missing`,
effectively circumventing the check as to whether `x` is the arguments.

Only usable within the body of [@model](@ref).

# Examples

```jldoctest; setup = :(using Distributions)
julia> @model function demo()
x = 1.0
@observe x ~ Normal()

return getlogp(__varinfo__)
end;

julia> demo()() == logpdf(Normal(), 1.0)
true

julia> @model function demo()
x = [1.0, ]
@observe x .~ Normal()

return getlogp(__varinfo__)
end;

julia> demo()() == logpdf(Normal(), 1.0)
true

julia> @model function demo(args)
x = args.x
@observe x ~ Normal()

return getlogp(__varinfo__)
end;

julia> demo((x = 1.0, ))() == logpdf(Normal(), 1.0)
true

julia> VarInfo(demo((x = missing, )))[@varname(x)] !== missing
true
```
"""
macro observe(ex)
return esc(generate_mainbody(__module__, ex, false; check_inargs=false))
end

"""
getargs_dottilde(x)

Expand Down