-
Notifications
You must be signed in to change notification settings - Fork 37
Different approach to how observations/missings are stored in the model #268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
10a1505
781a86d
8244cdb
66b228c
0a074bc
e79a0d0
852afb9
ab8ce2c
1979a64
32184f2
f070f82
1b32918
9df5eb2
8c1a1f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,40 +1,6 @@ | ||||||||||
| const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__) | ||||||||||
| const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng) | ||||||||||
|
|
||||||||||
| """ | ||||||||||
| isassumption(expr) | ||||||||||
|
|
||||||||||
| Return an expression that can be evaluated to check if `expr` is an assumption in the | ||||||||||
| model. | ||||||||||
|
|
||||||||||
| Let `expr` be `:(x[1])`. It is an assumption in the following cases: | ||||||||||
| 1. `x` is not among the input data to the model, | ||||||||||
| 2. `x` is among the input data to the model but with a value `missing`, or | ||||||||||
| 3. `x` is among the input data to the model with a value other than missing, | ||||||||||
| but `x[1] === missing`. | ||||||||||
|
|
||||||||||
| When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. | ||||||||||
| """ | ||||||||||
| function isassumption(expr::Union{Symbol,Expr}) | ||||||||||
| vn = gensym(:vn) | ||||||||||
|
|
||||||||||
| return quote | ||||||||||
| let $vn = $(varname(expr)) | ||||||||||
| # This branch should compile nicely in all cases except for partial missing data | ||||||||||
| # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` | ||||||||||
| if !$(DynamicPPL.inargnames)($vn, __model__) || | ||||||||||
| $(DynamicPPL.inmissings)($vn, __model__) | ||||||||||
| true | ||||||||||
| else | ||||||||||
| # Evaluate the LHS | ||||||||||
| $expr === missing | ||||||||||
| end | ||||||||||
| end | ||||||||||
| end | ||||||||||
| end | ||||||||||
|
|
||||||||||
| # failsafe: a literal is never an assumption | ||||||||||
| isassumption(expr) = :(false) | ||||||||||
|
|
||||||||||
| """ | ||||||||||
| isliteral(expr) | ||||||||||
|
|
@@ -137,8 +103,14 @@ end | |||||||||
| function model(mod, linenumbernode, expr, warn) | ||||||||||
| modelinfo = build_model_info(expr) | ||||||||||
|
|
||||||||||
| # Generate main body | ||||||||||
| modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) | ||||||||||
| # Generate main body and find all variable symbols | ||||||||||
| modelinfo[:body], modelinfo[:varnames] = generate_mainbody( | ||||||||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should |
||||||||||
| mod, modelinfo[:modeldef][:body], warn | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| # extract observations from that | ||||||||||
| modelinfo[:obsnames] = modelinfo[:allargs_syms] ∩ modelinfo[:varnames] | ||||||||||
| modelinfo[:latentnames] = setdiff(modelinfo[:varnames], modelinfo[:allargs_syms]) | ||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Btw, this will be incorrect for submodels and other similar non-top-level assume statements 😕 |
||||||||||
|
|
||||||||||
| return build_output(modelinfo, linenumbernode) | ||||||||||
| end | ||||||||||
|
|
@@ -167,8 +139,7 @@ function build_model_info(input_expr) | |||||||||
| modelinfo = Dict( | ||||||||||
| :allargs_exprs => [], | ||||||||||
| :allargs_syms => [], | ||||||||||
| :allargs_namedtuple => NamedTuple(), | ||||||||||
| :defaults_namedtuple => NamedTuple(), | ||||||||||
| :allargs_defaults => [], | ||||||||||
| :modeldef => modeldef, | ||||||||||
| ) | ||||||||||
| return modelinfo | ||||||||||
|
|
@@ -177,17 +148,18 @@ function build_model_info(input_expr) | |||||||||
| # Extract the positional and keyword arguments from the model definition. | ||||||||||
| allargs = vcat(modeldef[:args], modeldef[:kwargs]) | ||||||||||
|
|
||||||||||
| # Split the argument expressions and the default values. | ||||||||||
| allargs_exprs_defaults = map(allargs) do arg | ||||||||||
| MacroTools.@match arg begin | ||||||||||
| # Split the argument expressions and the default values, by unzipping allargs, taking care of | ||||||||||
| # the empty case | ||||||||||
| allargs_exprs, allargs_defaults = foldl(allargs; init=([], [])) do (ae, ad), arg | ||||||||||
| (expr, default) = MacroTools.@match arg begin | ||||||||||
| (x_ = val_) => (x, val) | ||||||||||
| x_ => (x, NO_DEFAULT) | ||||||||||
| end | ||||||||||
| push!(ae, expr) | ||||||||||
| push!(ad, default) | ||||||||||
| ae, ad | ||||||||||
| end | ||||||||||
|
|
||||||||||
| # Extract the expressions of the arguments, without default values. | ||||||||||
| allargs_exprs = first.(allargs_exprs_defaults) | ||||||||||
|
|
||||||||||
|
|
||||||||||
phipsgabler marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
| # Extract the names of the arguments. | ||||||||||
| allargs_syms = map(allargs_exprs) do arg | ||||||||||
| MacroTools.@match arg begin | ||||||||||
|
|
@@ -196,28 +168,11 @@ function build_model_info(input_expr) | |||||||||
| x_ => x | ||||||||||
| end | ||||||||||
| end | ||||||||||
|
|
||||||||||
| # Build named tuple expression of the argument symbols and variables of the same name. | ||||||||||
| allargs_namedtuple = to_namedtuple_expr(allargs_syms) | ||||||||||
|
|
||||||||||
| # Extract default values of the positional and keyword arguments. | ||||||||||
| default_syms = [] | ||||||||||
| default_vals = [] | ||||||||||
| for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults) | ||||||||||
| if val !== NO_DEFAULT | ||||||||||
| push!(default_syms, sym) | ||||||||||
| push!(default_vals, val) | ||||||||||
| end | ||||||||||
| end | ||||||||||
|
|
||||||||||
| # Build named tuple expression of the argument symbols with default values. | ||||||||||
| defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals) | ||||||||||
|
|
||||||||||
|
|
||||||||||
phipsgabler marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
| modelinfo = Dict( | ||||||||||
| :allargs_exprs => allargs_exprs, | ||||||||||
| :allargs_syms => allargs_syms, | ||||||||||
| :allargs_namedtuple => allargs_namedtuple, | ||||||||||
| :defaults_namedtuple => defaults_namedtuple, | ||||||||||
| :allargs_defaults => allargs_defaults, | ||||||||||
| :modeldef => modeldef, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
|
|
@@ -233,43 +188,50 @@ Generate the body of the main evaluation function from expression `expr` and arg | |||||||||
| If `warn` is true, a warning is displayed if internal variables are used in the model | ||||||||||
| definition. | ||||||||||
| """ | ||||||||||
| generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn) | ||||||||||
| function generate_mainbody(mod, expr, warn) | ||||||||||
| varnames = Symbol[] | ||||||||||
| body = generate_mainbody!(mod, Symbol[], varnames, expr, warn) | ||||||||||
| return body, varnames | ||||||||||
| end | ||||||||||
|
|
||||||||||
| generate_mainbody!(mod, found, x, warn) = x | ||||||||||
| function generate_mainbody!(mod, found, sym::Symbol, warn) | ||||||||||
| generate_mainbody!(mod, found_internals, varnames, x, warn) = x | ||||||||||
| function generate_mainbody!(mod, found_internals, sym::Symbol, warn) | ||||||||||
| if sym in DEPRECATED_INTERNALNAMES | ||||||||||
| newsym = Symbol(:_, sym, :__) | ||||||||||
| Base.depwarn( | ||||||||||
| "internal variable `$sym` is deprecated, use `$newsym` instead.", | ||||||||||
| :generate_mainbody!, | ||||||||||
| ) | ||||||||||
| return generate_mainbody!(mod, found, newsym, warn) | ||||||||||
| return generate_mainbody!(mod, found_internals, newsym, warn) | ||||||||||
| end | ||||||||||
|
|
||||||||||
| if warn && sym in INTERNALNAMES && sym ∉ found | ||||||||||
| if warn && sym in INTERNALNAMES && sym ∉ found_internals | ||||||||||
| @warn "you are using the internal variable `$sym`" | ||||||||||
| push!(found, sym) | ||||||||||
| push!(found_internals, sym) | ||||||||||
| end | ||||||||||
|
|
||||||||||
| return sym | ||||||||||
| end | ||||||||||
| function generate_mainbody!(mod, found, expr::Expr, warn) | ||||||||||
| function generate_mainbody!(mod, found_internals, varnames, expr::Expr, warn) | ||||||||||
| # Do not touch interpolated expressions | ||||||||||
| expr.head === :$ && return expr.args[1] | ||||||||||
|
|
||||||||||
| # If it's a macro, we expand it | ||||||||||
| if Meta.isexpr(expr, :macrocall) | ||||||||||
| return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) | ||||||||||
| return generate_mainbody!( | ||||||||||
| mod, found_internals, varnames, macroexpand(mod, expr; recursive=true), warn | ||||||||||
| ) | ||||||||||
| end | ||||||||||
|
|
||||||||||
| # Modify dotted tilde operators. | ||||||||||
| args_dottilde = getargs_dottilde(expr) | ||||||||||
| if args_dottilde !== nothing | ||||||||||
| L, R = args_dottilde | ||||||||||
| !isliteral(L) && push!(varnames, vsym(L)) | ||||||||||
| return Base.remove_linenums!( | ||||||||||
| generate_dot_tilde( | ||||||||||
| generate_mainbody!(mod, found, L, warn), | ||||||||||
| generate_mainbody!(mod, found, R, warn), | ||||||||||
| generate_mainbody!(mod, found_internals, varnames, L, warn), | ||||||||||
| generate_mainbody!(mod, found_internals, varnames, R, warn), | ||||||||||
| ), | ||||||||||
| ) | ||||||||||
| end | ||||||||||
|
|
@@ -278,15 +240,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn) | |||||||||
| args_tilde = getargs_tilde(expr) | ||||||||||
| if args_tilde !== nothing | ||||||||||
| L, R = args_tilde | ||||||||||
| !isliteral(L) && push!(varnames, vsym(L)) | ||||||||||
| return Base.remove_linenums!( | ||||||||||
| generate_tilde( | ||||||||||
| generate_mainbody!(mod, found, L, warn), | ||||||||||
| generate_mainbody!(mod, found, R, warn), | ||||||||||
| generate_mainbody!(mod, found_internals, varnames, L, warn), | ||||||||||
| generate_mainbody!(mod, found_internals, varnames, R, warn), | ||||||||||
| ), | ||||||||||
| ) | ||||||||||
| end | ||||||||||
|
|
||||||||||
| return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) | ||||||||||
| return Expr( | ||||||||||
| expr.head, | ||||||||||
| map(x -> generate_mainbody!(mod, found_internals, varnames, x, warn), expr.args)..., | ||||||||||
| ) | ||||||||||
| end | ||||||||||
|
|
||||||||||
| """ | ||||||||||
|
|
@@ -307,26 +273,26 @@ function generate_tilde(left, right) | |||||||||
|
|
||||||||||
| # Otherwise it is determined by the model or its value, | ||||||||||
| # if the LHS represents an observation | ||||||||||
| @gensym vn inds isassumption | ||||||||||
| @gensym vn inds isobservation | ||||||||||
| return quote | ||||||||||
| $vn = $(varname(left)) | ||||||||||
| $inds = $(vinds(left)) | ||||||||||
| $isassumption = $(DynamicPPL.isassumption(left)) | ||||||||||
| if $isassumption | ||||||||||
| $left = $(DynamicPPL.tilde_assume!)( | ||||||||||
| $isobservation = $(DynamicPPL.isobservation)($vn, __model__) | ||||||||||
| if $isobservation | ||||||||||
| $(DynamicPPL.tilde_observe!)( | ||||||||||
| __context__, | ||||||||||
| $(DynamicPPL.unwrap_right_vn)( | ||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), $vn | ||||||||||
| )..., | ||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), | ||||||||||
| $left, | ||||||||||
| $vn, | ||||||||||
| $inds, | ||||||||||
| __varinfo__, | ||||||||||
| ) | ||||||||||
| else | ||||||||||
| $(DynamicPPL.tilde_observe!)( | ||||||||||
| $left = $(DynamicPPL.tilde_assume!)( | ||||||||||
| __context__, | ||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), | ||||||||||
| $left, | ||||||||||
| $vn, | ||||||||||
| $(DynamicPPL.unwrap_right_vn)( | ||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), $vn | ||||||||||
| )..., | ||||||||||
| $inds, | ||||||||||
| __varinfo__, | ||||||||||
| ) | ||||||||||
|
|
@@ -351,26 +317,26 @@ function generate_dot_tilde(left, right) | |||||||||
|
|
||||||||||
| # Otherwise it is determined by the model or its value, | ||||||||||
| # if the LHS represents an observation | ||||||||||
| @gensym vn inds isassumption | ||||||||||
| @gensym vn inds isobservation | ||||||||||
| return quote | ||||||||||
| $vn = $(varname(left)) | ||||||||||
| $inds = $(vinds(left)) | ||||||||||
| $isassumption = $(DynamicPPL.isassumption(left)) | ||||||||||
| if $isassumption | ||||||||||
| $left .= $(DynamicPPL.dot_tilde_assume!)( | ||||||||||
| $isobservation = $(DynamicPPL.isobservation)($vn, __model__) | ||||||||||
| if $isobservation | ||||||||||
| $(DynamicPPL.dot_tilde_observe!)( | ||||||||||
| __context__, | ||||||||||
| $(DynamicPPL.unwrap_right_left_vns)( | ||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), $left, $vn | ||||||||||
| )..., | ||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), | ||||||||||
| $left, | ||||||||||
| $vn, | ||||||||||
| $inds, | ||||||||||
| __varinfo__, | ||||||||||
| ) | ||||||||||
| else | ||||||||||
| $(DynamicPPL.dot_tilde_observe!)( | ||||||||||
| $left .= $(DynamicPPL.dot_tilde_assume!)( | ||||||||||
| __context__, | ||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), | ||||||||||
| $left, | ||||||||||
| $vn, | ||||||||||
| $(DynamicPPL.unwrap_right_left_vns)( | ||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), $left, $vn | ||||||||||
| )..., | ||||||||||
| $inds, | ||||||||||
| __varinfo__, | ||||||||||
| ) | ||||||||||
|
|
@@ -413,10 +379,24 @@ function build_output(modelinfo, linenumbernode) | |||||||||
|
|
||||||||||
| ## Build the model function. | ||||||||||
|
|
||||||||||
| # Extract the named tuple expression of all arguments and the default values. | ||||||||||
| allargs_namedtuple = modelinfo[:allargs_namedtuple] | ||||||||||
| defaults_namedtuple = modelinfo[:defaults_namedtuple] | ||||||||||
| # Extract the named tuple expression of all arguments | ||||||||||
| allargs_newnames = [gensym(x) for x in modelinfo[:allargs_syms]] | ||||||||||
| allargs_wrapped = map(modelinfo[:allargs_syms], modelinfo[:allargs_defaults]) do x, d | ||||||||||
| if x ∈ modelinfo[:obsnames] | ||||||||||
| :($(DynamicPPL.Variable)($x, $d)) | ||||||||||
| else | ||||||||||
| :($(DynamicPPL.Constant)($x, $d)) | ||||||||||
| end | ||||||||||
| end | ||||||||||
| allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)] | ||||||||||
phipsgabler marked this conversation as resolved.
Show resolved
Hide resolved
phipsgabler marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
| allargs_namedtuple = to_namedtuple_expr(modelinfo[:allargs_syms], allargs_newnames) | ||||||||||
|
|
||||||||||
| internals_newnames = [gensym(x) for x in modelinfo[:latentnames]] | ||||||||||
| internals_decls = map(internals_newnames) do name | ||||||||||
| :($name = $(DynamicPPL.Variable)(missing)) | ||||||||||
| end | ||||||||||
| internals_namedtuple = to_namedtuple_expr(modelinfo[:latentnames], internals_newnames) | ||||||||||
|
|
||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||
| # Update the function body of the user-specified model. | ||||||||||
| # We use a name for the anonymous evaluator that does not conflict with other variables. | ||||||||||
| modeldef = modelinfo[:modeldef] | ||||||||||
|
|
@@ -427,11 +407,13 @@ function build_output(modelinfo, linenumbernode) | |||||||||
| modeldef[:body] = MacroTools.@q begin | ||||||||||
| $(linenumbernode) | ||||||||||
| $evaluator = $(MacroTools.combinedef(evaluatordef)) | ||||||||||
| $(allargs_decls...) | ||||||||||
| $(internals_decls...) | ||||||||||
| return $(DynamicPPL.Model)( | ||||||||||
| $(QuoteNode(modeldef[:name])), | ||||||||||
| $evaluator, | ||||||||||
| $allargs_namedtuple, | ||||||||||
|
Comment on lines
413
to
415
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||
| $defaults_namedtuple, | ||||||||||
| $internals_namedtuple, | ||||||||||
| ) | ||||||||||
| end | ||||||||||
|
|
||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -14,8 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) | |||
| require_gradient(spl::Sampler) = false | ||||
| require_particles(spl::Sampler) = false | ||||
|
|
||||
| _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) | ||||
| _getindex(x, inds::Tuple{}) = x | ||||
|
|
||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||
| # assume | ||||
| """ | ||||
|
|
||||
Uh oh!
There was an error while loading. Please reload this page.