Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 91 additions & 101 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,44 +1,12 @@
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
const DISTMSG =
"Right-hand side of a ~ must be subtype of Distribution or a vector of " *

"Distributions."

const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__)
const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)

"""
isassumption(expr)

Return an expression that can be evaluated to check if `expr` is an assumption in the
model.

Let `expr` be `:(x[1])`. It is an assumption in the following cases:
1. `x` is not among the input data to the model,
2. `x` is among the input data to the model but with a value `missing`, or
3. `x` is among the input data to the model with a value other than missing,
but `x[1] === missing`.

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
"""
function isassumption(expr::Union{Symbol,Expr})
vn = gensym(:vn)

return quote
let $vn = $(varname(expr))
# This branch should compile nicely in all cases except for partial missing data
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
if !$(DynamicPPL.inargnames)($vn, __model__) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
else
# Evaluate the LHS
$expr === missing
end
end
end
end

# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

"""
isliteral(expr)

Return `true` if `expr` is a literal, e.g. `1.0` or `[1.0, ]`, and `false` otherwise.
"""
isliteral(e) = false
Expand All @@ -47,7 +15,6 @@ isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args)

"""
check_tilde_rhs(x)

Check if the right-hand side `x` of a `~` is a `Distribution` or an array of
`Distributions`, then return `x`.
"""
Expand All @@ -63,21 +30,17 @@ check_tilde_rhs(x::AbstractArray{<:Distribution}) = x

"""
unwrap_right_vn(right, vn)

Return the unwrapped distribution on the right-hand side and variable name on the left-hand
side of a `~` expression such as `x ~ Normal()`.

This is used mainly to unwrap `NamedDist` distributions.
"""
unwrap_right_vn(right, vn) = right, vn
unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name)

"""
unwrap_right_left_vns(right, left, vns)

Return the unwrapped distributions on the right-hand side and values and variable names on the
left-hand side of a `.~` expression such as `x .~ Normal()`.

This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the
variables.
"""
Expand All @@ -104,6 +67,7 @@ function unwrap_right_left_vns(
return unwrap_right_left_vns(right, left, vns)
end


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

#################
# Main Compiler #
#################
Expand Down Expand Up @@ -137,9 +101,13 @@ end
function model(mod, linenumbernode, expr, warn)
modelinfo = build_model_info(expr)

# Generate main body
modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn)
# Generate main body and find all variable symbols
modelinfo[:body], modelinfo[:varnames] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
modelinfo[:body], modelinfo[:varnames] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn)
modelinfo[:body], modelinfo[:varnames] = generate_mainbody(
mod, modelinfo[:modeldef][:body], warn
)


# extract parameters and observations from that
modelinfo[:paramnames] = filter(x -> x ∉ modelinfo[:varnames], modelinfo[:allargs_syms])
modelinfo[:obsnames] = setdiff(modelinfo[:allargs_syms], modelinfo[:paramnames])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

return build_output(modelinfo, linenumbernode)
end

Expand Down Expand Up @@ -167,8 +135,8 @@ function build_model_info(input_expr)
modelinfo = Dict(
:allargs_exprs => [],
:allargs_syms => [],
:allargs_namedtuple => NamedTuple(),
:defaults_namedtuple => NamedTuple(),
# :allargs_namedtuple => NamedTuple(),
# :defaults_namedtuple => NamedTuple(),
:modeldef => modeldef,
)
return modelinfo
Expand Down Expand Up @@ -198,26 +166,26 @@ function build_model_info(input_expr)
end

# Build named tuple expression of the argument symbols and variables of the same name.
allargs_namedtuple = to_namedtuple_expr(allargs_syms)

# Extract default values of the positional and keyword arguments.
default_syms = []
default_vals = []
for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults)
if val !== NO_DEFAULT
push!(default_syms, sym)
push!(default_vals, val)
end
end
# allargs_namedtuple = to_namedtuple_expr(allargs_syms)

# # Extract default values of the positional and keyword arguments.
# default_syms = []
# default_vals = []
# for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults)
# if val !== NO_DEFAULT
# push!(default_syms, sym)
# push!(default_vals, val)
# end
# end

# Build named tuple expression of the argument symbols with default values.
defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals)
# # Build named tuple expression of the argument symbols with default values.
# defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals)

modelinfo = Dict(
:allargs_exprs => allargs_exprs,
:allargs_syms => allargs_syms,
:allargs_namedtuple => allargs_namedtuple,
:defaults_namedtuple => defaults_namedtuple,
# :allargs_namedtuple => allargs_namedtuple,
# :defaults_namedtuple => defaults_namedtuple,
:modeldef => modeldef,
)

Expand All @@ -233,43 +201,54 @@ Generate the body of the main evaluation function from expression `expr` and arg
If `warn` is true, a warning is displayed if internal variables are used in the model
definition.
"""
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
function generate_mainbody(mod, expr, warn)
varnames = Symbol[]
body = generate_mainbody!(mod, Symbol[], varnames, expr, warn)
return body, varnames
end

generate_mainbody!(mod, found, x, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, warn)
generate_mainbody!(mod, found_internals, varnames, x, warn) = x
function generate_mainbody!(mod, found_internals, sym::Symbol, warn)
if sym in DEPRECATED_INTERNALNAMES
newsym = Symbol(:_, sym, :__)
Base.depwarn(
"internal variable `$sym` is deprecated, use `$newsym` instead.",
:generate_mainbody!,
)
return generate_mainbody!(mod, found, newsym, warn)
return generate_mainbody!(mod, found_internals, newsym, warn)
end

if warn && sym in INTERNALNAMES && sym ∉ found
if warn && sym in INTERNALNAMES && sym ∉ found_internals
@warn "you are using the internal variable `$sym`"
push!(found, sym)
push!(found_internals, sym)
end

return sym
end
function generate_mainbody!(mod, found, expr::Expr, warn)
function generate_mainbody!(mod, found_internals, varnames, expr::Expr, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
return generate_mainbody!(
mod,
found_internals,
varnames,
macroexpand(mod, expr; recursive=true),
warn
Comment on lines +235 to +239
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
mod,
found_internals,
varnames,
macroexpand(mod, expr; recursive=true),
warn
mod, found_internals, varnames, macroexpand(mod, expr; recursive=true), warn

)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
push!(varnames, vsym(L))
return Base.remove_linenums!(
generate_dot_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found_internals, varnames, L, warn),
generate_mainbody!(mod, found_internals, varnames, R, warn),
),
)
end
Expand All @@ -278,15 +257,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
args_tilde = getargs_tilde(expr)
if args_tilde !== nothing
L, R = args_tilde
push!(varnames, vsym(L))
return Base.remove_linenums!(
generate_tilde(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
generate_mainbody!(mod, found_internals, varnames, L, warn),
generate_mainbody!(mod, found_internals, varnames, R, warn),
),
)
end

return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
return Expr(
expr.head,
map(x -> generate_mainbody!(mod, found_internals, varnames, x, warn), expr.args)...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
map(x -> generate_mainbody!(mod, found_internals, varnames, x, warn), expr.args)...
map(x -> generate_mainbody!(mod, found_internals, varnames, x, warn), expr.args)...,

)
end

"""
Expand All @@ -307,26 +290,26 @@ function generate_tilde(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
@gensym vn inds isobservation
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume!)(
$isobservation = $(DynamicPPL.isobservation)($vn, __model__)
if $isobservation
$(DynamicPPL.tilde_observe!)(
__context__,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$inds,
__varinfo__,
)
else
$(DynamicPPL.tilde_observe!)(
$left = $(DynamicPPL.tilde_assume!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
$inds,
__varinfo__,
)
Expand All @@ -351,26 +334,26 @@ function generate_dot_tilde(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn inds isassumption
@gensym vn inds isobservation
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume!)(
$isobservation = $(DynamicPPL.isobservation)($vn, __model__)
if $isobservation
$(DynamicPPL.dot_tilde_observe!)(
__context__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
)...,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$inds,
__varinfo__,
)
else
$(DynamicPPL.dot_tilde_observe!)(
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
)...,
$inds,
__varinfo__,
)
Expand Down Expand Up @@ -413,26 +396,33 @@ function build_output(modelinfo, linenumbernode)

## Build the model function.

# Extract the named tuple expression of all arguments and the default values.
allargs_namedtuple = modelinfo[:allargs_namedtuple]
defaults_namedtuple = modelinfo[:defaults_namedtuple]
# Extract the named tuple expression of all arguments
allargs_newnames = [gensym(x) for x in modelinfo[:allargs_syms]]
allargs_wrapped = [
x ∈ modelinfo[:obsnames] ? :($(DynamicPPL.Observation)($x)) : :($(DynamicPPL.Parameter)($x))
for x in modelinfo[:allargs_syms]
Comment on lines +402 to +403
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
x modelinfo[:obsnames] ? :($(DynamicPPL.Observation)($x)) : :($(DynamicPPL.Parameter)($x))
for x in modelinfo[:allargs_syms]
if x modelinfo[:obsnames]
:($(DynamicPPL.Observation)($x))
else
:($(DynamicPPL.Parameter)($x))
end for x in modelinfo[:allargs_syms]
]
allargs_decls = [
:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)

]
allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)]

allargs_namedtuple = to_namedtuple_expr(modelinfo[:allargs_syms], allargs_newnames)
# modelinfo[:allargs_namedtuple] = to_namedtuple_expr(modelinfo[:allargs_syms], args_vals)

# Update the function body of the user-specified model.
# We use a name for the anonymous evaluator that does not conflict with other variables.
modeldef = modelinfo[:modeldef]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

@gensym evaluator
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
# to the call site
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
# $(observation_checks...)
$evaluator = $(MacroTools.combinedef(evaluatordef))
$(allargs_decls...)
return $(DynamicPPL.Model)(
$(QuoteNode(modeldef[:name])),
$evaluator,
$allargs_namedtuple,
$defaults_namedtuple,
)
$allargs_namedtuple)
Comment on lines 423 to +425
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
$(QuoteNode(modeldef[:name])),
$evaluator,
$allargs_namedtuple,
$defaults_namedtuple,
)
$allargs_namedtuple)
$(QuoteNode(modeldef[:name])), $evaluator, $allargs_namedtuple
)

end

return :($(Base).@__doc__ $(MacroTools.combinedef(modeldef)))
Expand All @@ -442,7 +432,7 @@ function warn_empty(body)
if all(l -> isa(l, LineNumberNode), body.args)
@warn("Model definition seems empty, still continue.")
end
return nothing
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return
return nothing

end

"""
Expand Down
2 changes: 0 additions & 2 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
require_gradient(spl::Sampler) = false
require_particles(spl::Sampler) = false

_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds))
_getindex(x, inds::Tuple{}) = x

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

# assume
"""
Expand Down
Loading