-
Notifications
You must be signed in to change notification settings - Fork 37
Refactor model for the AbstractPPL way #244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,44 +1,12 @@ | ||||||||||||||||||||||
| const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " * | ||||||||||||||||||||||
| "Distributions." | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__) | ||||||||||||||||||||||
| const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| isassumption(expr) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Return an expression that can be evaluated to check if `expr` is an assumption in the | ||||||||||||||||||||||
| model. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Let `expr` be `:(x[1])`. It is an assumption in the following cases: | ||||||||||||||||||||||
| 1. `x` is not among the input data to the model, | ||||||||||||||||||||||
| 2. `x` is among the input data to the model but with a value `missing`, or | ||||||||||||||||||||||
| 3. `x` is among the input data to the model with a value other than missing, | ||||||||||||||||||||||
| but `x[1] === missing`. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| function isassumption(expr::Union{Symbol,Expr}) | ||||||||||||||||||||||
| vn = gensym(:vn) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return quote | ||||||||||||||||||||||
| let $vn = $(varname(expr)) | ||||||||||||||||||||||
| # This branch should compile nicely in all cases except for partial missing data | ||||||||||||||||||||||
| # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` | ||||||||||||||||||||||
| if !$(DynamicPPL.inargnames)($vn, __model__) || | ||||||||||||||||||||||
| $(DynamicPPL.inmissings)($vn, __model__) | ||||||||||||||||||||||
| true | ||||||||||||||||||||||
| else | ||||||||||||||||||||||
| # Evaluate the LHS | ||||||||||||||||||||||
| $expr === missing | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # failsafe: a literal is never an assumption | ||||||||||||||||||||||
| isassumption(expr) = :(false) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| isliteral(expr) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Return `true` if `expr` is a literal, e.g. `1.0` or `[1.0, ]`, and `false` otherwise. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| isliteral(e) = false | ||||||||||||||||||||||
|
|
@@ -47,7 +15,6 @@ isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args) | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| check_tilde_rhs(x) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Check if the right-hand side `x` of a `~` is a `Distribution` or an array of | ||||||||||||||||||||||
| `Distributions`, then return `x`. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
|
|
@@ -63,21 +30,17 @@ check_tilde_rhs(x::AbstractArray{<:Distribution}) = x | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| unwrap_right_vn(right, vn) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Return the unwrapped distribution on the right-hand side and variable name on the left-hand | ||||||||||||||||||||||
| side of a `~` expression such as `x ~ Normal()`. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| This is used mainly to unwrap `NamedDist` distributions. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| unwrap_right_vn(right, vn) = right, vn | ||||||||||||||||||||||
| unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| unwrap_right_left_vns(right, left, vns) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Return the unwrapped distributions on the right-hand side and values and variable names on the | ||||||||||||||||||||||
| left-hand side of a `.~` expression such as `x .~ Normal()`. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the | ||||||||||||||||||||||
| variables. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
|
|
@@ -104,6 +67,7 @@ function unwrap_right_left_vns( | |||||||||||||||||||||
| return unwrap_right_left_vns(right, left, vns) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| ################# | ||||||||||||||||||||||
| # Main Compiler # | ||||||||||||||||||||||
| ################# | ||||||||||||||||||||||
|
|
@@ -137,9 +101,13 @@ end | |||||||||||||||||||||
| function model(mod, linenumbernode, expr, warn) | ||||||||||||||||||||||
| modelinfo = build_model_info(expr) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Generate main body | ||||||||||||||||||||||
| modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) | ||||||||||||||||||||||
| # Generate main body and find all variable symbols | ||||||||||||||||||||||
| modelinfo[:body], modelinfo[:varnames] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # extract parameters and observations from that | ||||||||||||||||||||||
| modelinfo[:paramnames] = filter(x -> x ∉ modelinfo[:varnames], modelinfo[:allargs_syms]) | ||||||||||||||||||||||
| modelinfo[:obsnames] = setdiff(modelinfo[:allargs_syms], modelinfo[:paramnames]) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| return build_output(modelinfo, linenumbernode) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -167,8 +135,8 @@ function build_model_info(input_expr) | |||||||||||||||||||||
| modelinfo = Dict( | ||||||||||||||||||||||
| :allargs_exprs => [], | ||||||||||||||||||||||
| :allargs_syms => [], | ||||||||||||||||||||||
| :allargs_namedtuple => NamedTuple(), | ||||||||||||||||||||||
| :defaults_namedtuple => NamedTuple(), | ||||||||||||||||||||||
| # :allargs_namedtuple => NamedTuple(), | ||||||||||||||||||||||
| # :defaults_namedtuple => NamedTuple(), | ||||||||||||||||||||||
| :modeldef => modeldef, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| return modelinfo | ||||||||||||||||||||||
|
|
@@ -198,26 +166,26 @@ function build_model_info(input_expr) | |||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Build named tuple expression of the argument symbols and variables of the same name. | ||||||||||||||||||||||
| allargs_namedtuple = to_namedtuple_expr(allargs_syms) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Extract default values of the positional and keyword arguments. | ||||||||||||||||||||||
| default_syms = [] | ||||||||||||||||||||||
| default_vals = [] | ||||||||||||||||||||||
| for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults) | ||||||||||||||||||||||
| if val !== NO_DEFAULT | ||||||||||||||||||||||
| push!(default_syms, sym) | ||||||||||||||||||||||
| push!(default_vals, val) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| # allargs_namedtuple = to_namedtuple_expr(allargs_syms) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # # Extract default values of the positional and keyword arguments. | ||||||||||||||||||||||
| # default_syms = [] | ||||||||||||||||||||||
| # default_vals = [] | ||||||||||||||||||||||
| # for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults) | ||||||||||||||||||||||
| # if val !== NO_DEFAULT | ||||||||||||||||||||||
| # push!(default_syms, sym) | ||||||||||||||||||||||
| # push!(default_vals, val) | ||||||||||||||||||||||
| # end | ||||||||||||||||||||||
| # end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Build named tuple expression of the argument symbols with default values. | ||||||||||||||||||||||
| defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals) | ||||||||||||||||||||||
| # # Build named tuple expression of the argument symbols with default values. | ||||||||||||||||||||||
| # defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| modelinfo = Dict( | ||||||||||||||||||||||
| :allargs_exprs => allargs_exprs, | ||||||||||||||||||||||
| :allargs_syms => allargs_syms, | ||||||||||||||||||||||
| :allargs_namedtuple => allargs_namedtuple, | ||||||||||||||||||||||
| :defaults_namedtuple => defaults_namedtuple, | ||||||||||||||||||||||
| # :allargs_namedtuple => allargs_namedtuple, | ||||||||||||||||||||||
| # :defaults_namedtuple => defaults_namedtuple, | ||||||||||||||||||||||
| :modeldef => modeldef, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -233,43 +201,54 @@ Generate the body of the main evaluation function from expression `expr` and arg | |||||||||||||||||||||
| If `warn` is true, a warning is displayed if internal variables are used in the model | ||||||||||||||||||||||
| definition. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn) | ||||||||||||||||||||||
| function generate_mainbody(mod, expr, warn) | ||||||||||||||||||||||
| varnames = Symbol[] | ||||||||||||||||||||||
| body = generate_mainbody!(mod, Symbol[], varnames, expr, warn) | ||||||||||||||||||||||
| return body, varnames | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| generate_mainbody!(mod, found, x, warn) = x | ||||||||||||||||||||||
| function generate_mainbody!(mod, found, sym::Symbol, warn) | ||||||||||||||||||||||
| generate_mainbody!(mod, found_internals, varnames, x, warn) = x | ||||||||||||||||||||||
| function generate_mainbody!(mod, found_internals, sym::Symbol, warn) | ||||||||||||||||||||||
| if sym in DEPRECATED_INTERNALNAMES | ||||||||||||||||||||||
| newsym = Symbol(:_, sym, :__) | ||||||||||||||||||||||
| Base.depwarn( | ||||||||||||||||||||||
| "internal variable `$sym` is deprecated, use `$newsym` instead.", | ||||||||||||||||||||||
| :generate_mainbody!, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| return generate_mainbody!(mod, found, newsym, warn) | ||||||||||||||||||||||
| return generate_mainbody!(mod, found_internals, newsym, warn) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if warn && sym in INTERNALNAMES && sym ∉ found | ||||||||||||||||||||||
| if warn && sym in INTERNALNAMES && sym ∉ found_internals | ||||||||||||||||||||||
| @warn "you are using the internal variable `$sym`" | ||||||||||||||||||||||
| push!(found, sym) | ||||||||||||||||||||||
| push!(found_internals, sym) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return sym | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| function generate_mainbody!(mod, found, expr::Expr, warn) | ||||||||||||||||||||||
| function generate_mainbody!(mod, found_internals, varnames, expr::Expr, warn) | ||||||||||||||||||||||
| # Do not touch interpolated expressions | ||||||||||||||||||||||
| expr.head === :$ && return expr.args[1] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # If it's a macro, we expand it | ||||||||||||||||||||||
| if Meta.isexpr(expr, :macrocall) | ||||||||||||||||||||||
| return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) | ||||||||||||||||||||||
| return generate_mainbody!( | ||||||||||||||||||||||
| mod, | ||||||||||||||||||||||
| found_internals, | ||||||||||||||||||||||
| varnames, | ||||||||||||||||||||||
| macroexpand(mod, expr; recursive=true), | ||||||||||||||||||||||
| warn | ||||||||||||||||||||||
|
Comment on lines
+235
to
+239
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Modify dotted tilde operators. | ||||||||||||||||||||||
| args_dottilde = getargs_dottilde(expr) | ||||||||||||||||||||||
| if args_dottilde !== nothing | ||||||||||||||||||||||
| L, R = args_dottilde | ||||||||||||||||||||||
| push!(varnames, vsym(L)) | ||||||||||||||||||||||
| return Base.remove_linenums!( | ||||||||||||||||||||||
| generate_dot_tilde( | ||||||||||||||||||||||
| generate_mainbody!(mod, found, L, warn), | ||||||||||||||||||||||
| generate_mainbody!(mod, found, R, warn), | ||||||||||||||||||||||
| generate_mainbody!(mod, found_internals, varnames, L, warn), | ||||||||||||||||||||||
| generate_mainbody!(mod, found_internals, varnames, R, warn), | ||||||||||||||||||||||
| ), | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
@@ -278,15 +257,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn) | |||||||||||||||||||||
| args_tilde = getargs_tilde(expr) | ||||||||||||||||||||||
| if args_tilde !== nothing | ||||||||||||||||||||||
| L, R = args_tilde | ||||||||||||||||||||||
| push!(varnames, vsym(L)) | ||||||||||||||||||||||
| return Base.remove_linenums!( | ||||||||||||||||||||||
| generate_tilde( | ||||||||||||||||||||||
| generate_mainbody!(mod, found, L, warn), | ||||||||||||||||||||||
| generate_mainbody!(mod, found, R, warn), | ||||||||||||||||||||||
| generate_mainbody!(mod, found_internals, varnames, L, warn), | ||||||||||||||||||||||
| generate_mainbody!(mod, found_internals, varnames, R, warn), | ||||||||||||||||||||||
| ), | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) | ||||||||||||||||||||||
| return Expr( | ||||||||||||||||||||||
| expr.head, | ||||||||||||||||||||||
| map(x -> generate_mainbody!(mod, found_internals, varnames, x, warn), expr.args)... | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| """ | ||||||||||||||||||||||
|
|
@@ -307,26 +290,26 @@ function generate_tilde(left, right) | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Otherwise it is determined by the model or its value, | ||||||||||||||||||||||
| # if the LHS represents an observation | ||||||||||||||||||||||
| @gensym vn inds isassumption | ||||||||||||||||||||||
| @gensym vn inds isobservation | ||||||||||||||||||||||
| return quote | ||||||||||||||||||||||
| $vn = $(varname(left)) | ||||||||||||||||||||||
| $inds = $(vinds(left)) | ||||||||||||||||||||||
| $isassumption = $(DynamicPPL.isassumption(left)) | ||||||||||||||||||||||
| if $isassumption | ||||||||||||||||||||||
| $left = $(DynamicPPL.tilde_assume!)( | ||||||||||||||||||||||
| $isobservation = $(DynamicPPL.isobservation)($vn, __model__) | ||||||||||||||||||||||
| if $isobservation | ||||||||||||||||||||||
| $(DynamicPPL.tilde_observe!)( | ||||||||||||||||||||||
| __context__, | ||||||||||||||||||||||
| $(DynamicPPL.unwrap_right_vn)( | ||||||||||||||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), $vn | ||||||||||||||||||||||
| )..., | ||||||||||||||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), | ||||||||||||||||||||||
| $left, | ||||||||||||||||||||||
| $vn, | ||||||||||||||||||||||
| $inds, | ||||||||||||||||||||||
| __varinfo__, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| else | ||||||||||||||||||||||
| $(DynamicPPL.tilde_observe!)( | ||||||||||||||||||||||
| $left = $(DynamicPPL.tilde_assume!)( | ||||||||||||||||||||||
| __context__, | ||||||||||||||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), | ||||||||||||||||||||||
| $left, | ||||||||||||||||||||||
| $vn, | ||||||||||||||||||||||
| $(DynamicPPL.unwrap_right_vn)( | ||||||||||||||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), $vn | ||||||||||||||||||||||
| )..., | ||||||||||||||||||||||
| $inds, | ||||||||||||||||||||||
| __varinfo__, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
@@ -351,26 +334,26 @@ function generate_dot_tilde(left, right) | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Otherwise it is determined by the model or its value, | ||||||||||||||||||||||
| # if the LHS represents an observation | ||||||||||||||||||||||
| @gensym vn inds isassumption | ||||||||||||||||||||||
| @gensym vn inds isobservation | ||||||||||||||||||||||
| return quote | ||||||||||||||||||||||
| $vn = $(varname(left)) | ||||||||||||||||||||||
| $inds = $(vinds(left)) | ||||||||||||||||||||||
| $isassumption = $(DynamicPPL.isassumption(left)) | ||||||||||||||||||||||
| if $isassumption | ||||||||||||||||||||||
| $left .= $(DynamicPPL.dot_tilde_assume!)( | ||||||||||||||||||||||
| $isobservation = $(DynamicPPL.isobservation)($vn, __model__) | ||||||||||||||||||||||
| if $isobservation | ||||||||||||||||||||||
| $(DynamicPPL.dot_tilde_observe!)( | ||||||||||||||||||||||
| __context__, | ||||||||||||||||||||||
| $(DynamicPPL.unwrap_right_left_vns)( | ||||||||||||||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), $left, $vn | ||||||||||||||||||||||
| )..., | ||||||||||||||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), | ||||||||||||||||||||||
| $left, | ||||||||||||||||||||||
| $vn, | ||||||||||||||||||||||
| $inds, | ||||||||||||||||||||||
| __varinfo__, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| else | ||||||||||||||||||||||
| $(DynamicPPL.dot_tilde_observe!)( | ||||||||||||||||||||||
| $left .= $(DynamicPPL.dot_tilde_assume!)( | ||||||||||||||||||||||
| __context__, | ||||||||||||||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), | ||||||||||||||||||||||
| $left, | ||||||||||||||||||||||
| $vn, | ||||||||||||||||||||||
| $(DynamicPPL.unwrap_right_left_vns)( | ||||||||||||||||||||||
| $(DynamicPPL.check_tilde_rhs)($right), $left, $vn | ||||||||||||||||||||||
| )..., | ||||||||||||||||||||||
| $inds, | ||||||||||||||||||||||
| __varinfo__, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
@@ -413,26 +396,33 @@ function build_output(modelinfo, linenumbernode) | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| ## Build the model function. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Extract the named tuple expression of all arguments and the default values. | ||||||||||||||||||||||
| allargs_namedtuple = modelinfo[:allargs_namedtuple] | ||||||||||||||||||||||
| defaults_namedtuple = modelinfo[:defaults_namedtuple] | ||||||||||||||||||||||
| # Extract the named tuple expression of all arguments | ||||||||||||||||||||||
| allargs_newnames = [gensym(x) for x in modelinfo[:allargs_syms]] | ||||||||||||||||||||||
| allargs_wrapped = [ | ||||||||||||||||||||||
| x ∈ modelinfo[:obsnames] ? :($(DynamicPPL.Observation)($x)) : :($(DynamicPPL.Parameter)($x)) | ||||||||||||||||||||||
| for x in modelinfo[:allargs_syms] | ||||||||||||||||||||||
|
Comment on lines
+402
to
+403
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| ] | ||||||||||||||||||||||
| allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)] | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| allargs_namedtuple = to_namedtuple_expr(modelinfo[:allargs_syms], allargs_newnames) | ||||||||||||||||||||||
| # modelinfo[:allargs_namedtuple] = to_namedtuple_expr(modelinfo[:allargs_syms], args_vals) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| # Update the function body of the user-specified model. | ||||||||||||||||||||||
| # We use a name for the anonymous evaluator that does not conflict with other variables. | ||||||||||||||||||||||
| modeldef = modelinfo[:modeldef] | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| @gensym evaluator | ||||||||||||||||||||||
| # We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure | ||||||||||||||||||||||
| # that no new `LineNumberNode`s are added apart from the reference `linenumbernode` | ||||||||||||||||||||||
| # to the call site | ||||||||||||||||||||||
| modeldef[:body] = MacroTools.@q begin | ||||||||||||||||||||||
| $(linenumbernode) | ||||||||||||||||||||||
| # $(observation_checks...) | ||||||||||||||||||||||
| $evaluator = $(MacroTools.combinedef(evaluatordef)) | ||||||||||||||||||||||
| $(allargs_decls...) | ||||||||||||||||||||||
| return $(DynamicPPL.Model)( | ||||||||||||||||||||||
| $(QuoteNode(modeldef[:name])), | ||||||||||||||||||||||
| $evaluator, | ||||||||||||||||||||||
| $allargs_namedtuple, | ||||||||||||||||||||||
| $defaults_namedtuple, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| $allargs_namedtuple) | ||||||||||||||||||||||
|
Comment on lines
423
to
+425
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return :($(Base).@__doc__ $(MacroTools.combinedef(modeldef))) | ||||||||||||||||||||||
|
|
@@ -442,7 +432,7 @@ function warn_empty(body) | |||||||||||||||||||||
| if all(l -> isa(l, LineNumberNode), body.args) | ||||||||||||||||||||||
| @warn("Model definition seems empty, still continue.") | ||||||||||||||||||||||
| end | ||||||||||||||||||||||
| return nothing | ||||||||||||||||||||||
| return | ||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
| end | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| """ | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -14,8 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) | |||
| require_gradient(spl::Sampler) = false | ||||
| require_particles(spl::Sampler) = false | ||||
|
|
||||
| _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) | ||||
| _getindex(x, inds::Tuple{}) = x | ||||
|
|
||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||
| # assume | ||||
| """ | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶