Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ export AbstractVarInfo,
vectorize,
# Model
Model,
getmissings,
getargnames,
getargumentnames,
generated_quantities,
# Samplers
Sampler,
Expand Down
188 changes: 85 additions & 103 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,6 @@
const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__)
const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)

"""
isassumption(expr)

Return an expression that can be evaluated to check if `expr` is an assumption in the
model.

Let `expr` be `:(x[1])`. It is an assumption in the following cases:
1. `x` is not among the input data to the model,
2. `x` is among the input data to the model but with a value `missing`, or
3. `x` is among the input data to the model with a value other than missing,
but `x[1] === missing`.

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
"""
function isassumption(expr::Union{Symbol,Expr})
vn = gensym(:vn)

return quote
let $vn = $(varname(expr))
# This branch should compile nicely in all cases except for partial missing data
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
if !$(DynamicPPL.inargnames)($vn, __model__) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
else
# Evaluate the LHS
$expr === missing
end
end
end
end

# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

"""
isliteral(expr)
Expand Down Expand Up @@ -137,8 +103,14 @@ end
function model(mod, linenumbernode, expr, warn)
modelinfo = build_model_info(expr)

# Generate main body
modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn)
# Generate main body and find all variable symbols
modelinfo[:body], modelinfo[:varnames] = generate_mainbody(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should generate_mainbody get a new name? I figured there is no need to traverse the expression twice, so I just added the varname extraction here...

mod, modelinfo[:modeldef][:body], warn
)

# extract observations from that
modelinfo[:obsnames] = modelinfo[:allargs_syms] ∩ modelinfo[:varnames]
modelinfo[:latentnames] = setdiff(modelinfo[:varnames], modelinfo[:allargs_syms])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, this will be incorrect for submodels and other similar non-top-level assume statements 😕


return build_output(modelinfo, linenumbernode)
end
Expand Down Expand Up @@ -167,8 +139,7 @@ function build_model_info(input_expr)
modelinfo = Dict(
:allargs_exprs => [],
:allargs_syms => [],
:allargs_namedtuple => NamedTuple(),
:defaults_namedtuple => NamedTuple(),
:allargs_defaults => [],
:modeldef => modeldef,
)
return modelinfo
Expand All @@ -177,17 +148,18 @@ function build_model_info(input_expr)
# Extract the positional and keyword arguments from the model definition.
allargs = vcat(modeldef[:args], modeldef[:kwargs])

# Split the argument expressions and the default values.
allargs_exprs_defaults = map(allargs) do arg
MacroTools.@match arg begin
# Split the argument expressions and the default values, by unzipping allargs, taking care of
# the empty case
allargs_exprs, allargs_defaults = foldl(allargs; init=([], [])) do (ae, ad), arg
(expr, default) = MacroTools.@match arg begin
(x_ = val_) => (x, val)
x_ => (x, NO_DEFAULT)
end
push!(ae, expr)
push!(ad, default)
ae, ad
end

# Extract the expressions of the arguments, without default values.
allargs_exprs = first.(allargs_exprs_defaults)


# Extract the names of the arguments.
allargs_syms = map(allargs_exprs) do arg
MacroTools.@match arg begin
Expand All @@ -196,28 +168,11 @@ function build_model_info(input_expr)
x_ => x
end
end

# Build named tuple expression of the argument symbols and variables of the same name.
allargs_namedtuple = to_namedtuple_expr(allargs_syms)

# Extract default values of the positional and keyword arguments.
default_syms = []
default_vals = []
for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults)
if val !== NO_DEFAULT
push!(default_syms, sym)
push!(default_vals, val)
end
end

# Build named tuple expression of the argument symbols with default values.
defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals)


modelinfo = Dict(
:allargs_exprs => allargs_exprs,
:allargs_syms => allargs_syms,
:allargs_namedtuple => allargs_namedtuple,
:defaults_namedtuple => defaults_namedtuple,
:allargs_defaults => allargs_defaults,
:modeldef => modeldef,
)

Expand All @@ -233,43 +188,50 @@ Generate the body of the main evaluation function from expression `expr` and arg
If `warn` is true, a warning is displayed if internal variables are used in the model
definition.
"""
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
function generate_mainbody(mod, expr, warn)
varnames = Symbol[]
body = generate_mainbody!(mod, Symbol[], varnames, expr, warn)
return body, varnames
end

generate_mainbody!(mod, found, x, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, warn)
generate_mainbody!(mod, found_internals, varnames, x, warn) = x
function generate_mainbody!(mod, found_internals, sym::Symbol, warn)
if sym in DEPRECATED_INTERNALNAMES
newsym = Symbol(:_, sym, :__)
Base.depwarn(
"internal variable `$sym` is deprecated, use `$newsym` instead.",
:generate_mainbody!,
)
return generate_mainbody!(mod, found, newsym, warn)
return generate_mainbody!(mod, found_internals, newsym, warn)
end

if warn && sym in INTERNALNAMES && sym ∉ found
if warn && sym in INTERNALNAMES && sym ∉ found_internals
@warn "you are using the internal variable `$sym`"
push!(found, sym)
push!(found_internals, sym)
end

return sym
end
function generate_mainbody!(mod, found, expr::Expr, warn)
function generate_mainbody!(mod, found_internals, varnames, expr::Expr, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
return generate_mainbody!(
mod, found_internals, varnames, macroexpand(mod, expr; recursive=true), warn
)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
!isliteral(L) && push!(varnames, vsym(L))
return Base.remove_linenums!(
generate_dot_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found_internals, varnames, L, warn),
generate_mainbody!(mod, found_internals, varnames, R, warn),
),
)
end
Expand All @@ -278,15 +240,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
args_tilde = getargs_tilde(expr)
if args_tilde !== nothing
L, R = args_tilde
!isliteral(L) && push!(varnames, vsym(L))
return Base.remove_linenums!(
generate_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found_internals, varnames, L, warn),
generate_mainbody!(mod, found_internals, varnames, R, warn),
),
)
end

return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
return Expr(
expr.head,
map(x -> generate_mainbody!(mod, found_internals, varnames, x, warn), expr.args)...,
)
end

"""
Expand All @@ -307,26 +273,26 @@ function generate_tilde(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
@gensym vn inds isobservation
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume!)(
$isobservation = $(DynamicPPL.isobservation)($vn, __model__)
if $isobservation
$(DynamicPPL.tilde_observe!)(
__context__,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$inds,
__varinfo__,
)
else
$(DynamicPPL.tilde_observe!)(
$left = $(DynamicPPL.tilde_assume!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
$inds,
__varinfo__,
)
Expand All @@ -351,26 +317,26 @@ function generate_dot_tilde(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
@gensym vn inds isobservation
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume!)(
$isobservation = $(DynamicPPL.isobservation)($vn, __model__)
if $isobservation
$(DynamicPPL.dot_tilde_observe!)(
__context__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
)...,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$inds,
__varinfo__,
)
else
$(DynamicPPL.dot_tilde_observe!)(
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
)...,
$inds,
__varinfo__,
)
Expand Down Expand Up @@ -413,10 +379,24 @@ function build_output(modelinfo, linenumbernode)

## Build the model function.

# Extract the named tuple expression of all arguments and the default values.
allargs_namedtuple = modelinfo[:allargs_namedtuple]
defaults_namedtuple = modelinfo[:defaults_namedtuple]
# Extract the named tuple expression of all arguments
allargs_newnames = [gensym(x) for x in modelinfo[:allargs_syms]]
allargs_wrapped = map(modelinfo[:allargs_syms], modelinfo[:allargs_defaults]) do x, d
if x ∈ modelinfo[:obsnames]
:($(DynamicPPL.Variable)($x, $d))
else
:($(DynamicPPL.Constant)($x, $d))
end
end
allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)]
allargs_namedtuple = to_namedtuple_expr(modelinfo[:allargs_syms], allargs_newnames)

internals_newnames = [gensym(x) for x in modelinfo[:latentnames]]
internals_decls = map(internals_newnames) do name
:($name = $(DynamicPPL.Variable)(missing))
end
internals_namedtuple = to_namedtuple_expr(modelinfo[:latentnames], internals_newnames)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

# Update the function body of the user-specified model.
# We use a name for the anonymous evaluator that does not conflict with other variables.
modeldef = modelinfo[:modeldef]
Expand All @@ -427,11 +407,13 @@ function build_output(modelinfo, linenumbernode)
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
$evaluator = $(MacroTools.combinedef(evaluatordef))
$(allargs_decls...)
$(internals_decls...)
return $(DynamicPPL.Model)(
$(QuoteNode(modeldef[:name])),
$evaluator,
$allargs_namedtuple,
Comment on lines 413 to 415
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
$(QuoteNode(modeldef[:name])),
$evaluator,
$allargs_namedtuple,
$(QuoteNode(modeldef[:name])), $evaluator, $allargs_namedtuple

$defaults_namedtuple,
$internals_namedtuple,
)
end

Expand Down
2 changes: 0 additions & 2 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
require_gradient(spl::Sampler) = false
require_particles(spl::Sampler) = false

_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds))
_getindex(x, inds::Tuple{}) = x

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

# assume
"""
Expand Down
Loading