diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..0434438b5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -67,8 +67,7 @@ export AbstractVarInfo, vectorize, # Model Model, - getmissings, - getargnames, + getargumentnames, generated_quantities, # Samplers Sampler, diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..7b6b3184f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,40 +1,6 @@ const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__) const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng) -""" - isassumption(expr) - -Return an expression that can be evaluated to check if `expr` is an assumption in the -model. - -Let `expr` be `:(x[1])`. It is an assumption in the following cases: - 1. `x` is not among the input data to the model, - 2. `x` is among the input data to the model but with a value `missing`, or - 3. `x` is among the input data to the model with a value other than missing, - but `x[1] === missing`. - -When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. -""" -function isassumption(expr::Union{Symbol,Expr}) - vn = gensym(:vn) - - return quote - let $vn = $(varname(expr)) - # This branch should compile nicely in all cases except for partial missing data - # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) - true - else - # Evaluate the LHS - $expr === missing - end - end - end -end - -# failsafe: a literal is never an assumption -isassumption(expr) = :(false) """ isliteral(expr) @@ -137,8 +103,14 @@ end function model(mod, linenumbernode, expr, warn) modelinfo = build_model_info(expr) - # Generate main body - modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) + # Generate main body and find all variable symbols + modelinfo[:body], modelinfo[:varnames] = generate_mainbody( + mod, modelinfo[:modeldef][:body], warn + ) + + # extract observations from that + modelinfo[:obsnames] = modelinfo[:allargs_syms] ∩ modelinfo[:varnames] + modelinfo[:latentnames] = setdiff(modelinfo[:varnames], modelinfo[:allargs_syms]) return build_output(modelinfo, linenumbernode) end @@ -167,8 +139,7 @@ function build_model_info(input_expr) modelinfo = Dict( :allargs_exprs => [], :allargs_syms => [], - :allargs_namedtuple => NamedTuple(), - :defaults_namedtuple => NamedTuple(), + :allargs_defaults => [], :modeldef => modeldef, ) return modelinfo @@ -177,17 +148,18 @@ function build_model_info(input_expr) # Extract the positional and keyword arguments from the model definition. allargs = vcat(modeldef[:args], modeldef[:kwargs]) - # Split the argument expressions and the default values. - allargs_exprs_defaults = map(allargs) do arg - MacroTools.@match arg begin + # Split the argument expressions and the default values, by unzipping allargs, taking care of + # the empty case + allargs_exprs, allargs_defaults = foldl(allargs; init=([], [])) do (ae, ad), arg + (expr, default) = MacroTools.@match arg begin (x_ = val_) => (x, val) x_ => (x, NO_DEFAULT) end + push!(ae, expr) + push!(ad, default) + ae, ad end - - # Extract the expressions of the arguments, without default values. - allargs_exprs = first.(allargs_exprs_defaults) - + # Extract the names of the arguments. allargs_syms = map(allargs_exprs) do arg MacroTools.@match arg begin @@ -196,28 +168,11 @@ function build_model_info(input_expr) x_ => x end end - - # Build named tuple expression of the argument symbols and variables of the same name. - allargs_namedtuple = to_namedtuple_expr(allargs_syms) - - # Extract default values of the positional and keyword arguments. - default_syms = [] - default_vals = [] - for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults) - if val !== NO_DEFAULT - push!(default_syms, sym) - push!(default_vals, val) - end - end - - # Build named tuple expression of the argument symbols with default values. - defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals) - + modelinfo = Dict( :allargs_exprs => allargs_exprs, :allargs_syms => allargs_syms, - :allargs_namedtuple => allargs_namedtuple, - :defaults_namedtuple => defaults_namedtuple, + :allargs_defaults => allargs_defaults, :modeldef => modeldef, ) @@ -233,43 +188,50 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. """ -generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn) +function generate_mainbody(mod, expr, warn) + varnames = Symbol[] + body = generate_mainbody!(mod, Symbol[], varnames, expr, warn) + return body, varnames +end -generate_mainbody!(mod, found, x, warn) = x -function generate_mainbody!(mod, found, sym::Symbol, warn) +generate_mainbody!(mod, found_internals, varnames, x, warn) = x +function generate_mainbody!(mod, found_internals, sym::Symbol, warn) if sym in DEPRECATED_INTERNALNAMES newsym = Symbol(:_, sym, :__) Base.depwarn( "internal variable `$sym` is deprecated, use `$newsym` instead.", :generate_mainbody!, ) - return generate_mainbody!(mod, found, newsym, warn) + return generate_mainbody!(mod, found_internals, newsym, warn) end - if warn && sym in INTERNALNAMES && sym ∉ found + if warn && sym in INTERNALNAMES && sym ∉ found_internals @warn "you are using the internal variable `$sym`" - push!(found, sym) + push!(found_internals, sym) end return sym end -function generate_mainbody!(mod, found, expr::Expr, warn) +function generate_mainbody!(mod, found_internals, varnames, expr::Expr, warn) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) + return generate_mainbody!( + mod, found_internals, varnames, macroexpand(mod, expr; recursive=true), warn + ) end # Modify dotted tilde operators. args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing L, R = args_dottilde + !isliteral(L) && push!(varnames, vsym(L)) return Base.remove_linenums!( generate_dot_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found_internals, varnames, L, warn), + generate_mainbody!(mod, found_internals, varnames, R, warn), ), ) end @@ -278,15 +240,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn) args_tilde = getargs_tilde(expr) if args_tilde !== nothing L, R = args_tilde + !isliteral(L) && push!(varnames, vsym(L)) return Base.remove_linenums!( generate_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found_internals, varnames, L, warn), + generate_mainbody!(mod, found_internals, varnames, R, warn), ), ) end - return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) + return Expr( + expr.head, + map(x -> generate_mainbody!(mod, found_internals, varnames, x, warn), expr.args)..., + ) end """ @@ -307,26 +273,26 @@ function generate_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn inds isobservation return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) - if $isassumption - $left = $(DynamicPPL.tilde_assume!)( + $isobservation = $(DynamicPPL.isobservation)($vn, __model__) + if $isobservation + $(DynamicPPL.tilde_observe!)( __context__, - $(DynamicPPL.unwrap_right_vn)( - $(DynamicPPL.check_tilde_rhs)($right), $vn - )..., + $(DynamicPPL.check_tilde_rhs)($right), + $left, + $vn, $inds, __varinfo__, ) else - $(DynamicPPL.tilde_observe!)( + $left = $(DynamicPPL.tilde_assume!)( __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - $vn, + $(DynamicPPL.unwrap_right_vn)( + $(DynamicPPL.check_tilde_rhs)($right), $vn + )..., $inds, __varinfo__, ) @@ -351,26 +317,26 @@ function generate_dot_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn inds isobservation return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) - if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume!)( + $isobservation = $(DynamicPPL.isobservation)($vn, __model__) + if $isobservation + $(DynamicPPL.dot_tilde_observe!)( __context__, - $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $left, $vn - )..., + $(DynamicPPL.check_tilde_rhs)($right), + $left, + $vn, $inds, __varinfo__, ) else - $(DynamicPPL.dot_tilde_observe!)( + $left .= $(DynamicPPL.dot_tilde_assume!)( __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - $vn, + $(DynamicPPL.unwrap_right_left_vns)( + $(DynamicPPL.check_tilde_rhs)($right), $left, $vn + )..., $inds, __varinfo__, ) @@ -413,10 +379,24 @@ function build_output(modelinfo, linenumbernode) ## Build the model function. - # Extract the named tuple expression of all arguments and the default values. - allargs_namedtuple = modelinfo[:allargs_namedtuple] - defaults_namedtuple = modelinfo[:defaults_namedtuple] + # Extract the named tuple expression of all arguments + allargs_newnames = [gensym(x) for x in modelinfo[:allargs_syms]] + allargs_wrapped = map(modelinfo[:allargs_syms], modelinfo[:allargs_defaults]) do x, d + if x ∈ modelinfo[:obsnames] + :($(DynamicPPL.Variable)($x, $d)) + else + :($(DynamicPPL.Constant)($x, $d)) + end + end + allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)] + allargs_namedtuple = to_namedtuple_expr(modelinfo[:allargs_syms], allargs_newnames) + internals_newnames = [gensym(x) for x in modelinfo[:latentnames]] + internals_decls = map(internals_newnames) do name + :($name = $(DynamicPPL.Variable)(missing)) + end + internals_namedtuple = to_namedtuple_expr(modelinfo[:latentnames], internals_newnames) + # Update the function body of the user-specified model. # We use a name for the anonymous evaluator that does not conflict with other variables. modeldef = modelinfo[:modeldef] @@ -427,11 +407,13 @@ function build_output(modelinfo, linenumbernode) modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) + $(allargs_decls...) + $(internals_decls...) return $(DynamicPPL.Model)( $(QuoteNode(modeldef[:name])), $evaluator, $allargs_namedtuple, - $defaults_namedtuple, + $internals_namedtuple, ) end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..0ab8b4616 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,8 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) -_getindex(x, inds::Tuple{}) = x # assume """ diff --git a/src/model.jl b/src/model.jl index 9ec047a44..976e54cf8 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,78 +1,62 @@ -""" - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} - name::Symbol - f::F - args::NamedTuple{argnames,Targs} - defaults::NamedTuple{defaultnames,Tdefaults} - end -A `Model` struct with model evaluation function of type `F`, arguments of names `argnames` -types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, and missing -arguments `missings`. +""" + abstract type Argument{T,Tdefault} end -Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`. +Parametric wrapper type for model arguments. +""" +abstract type Argument{T,Tdefault} end -An argument with a type of `Missing` will be in `missings` by default. However, in -non-traditional use-cases `missings` can be defined differently. All variables in `missings` -are treated as random variables rather than observations. +struct Variable{T,Tdefault} <: Argument{T,Tdefault} + value::T + default::Tdefault +end -The default arguments are used internally when constructing instances of the same model with -different arguments. +Variable(x) = Variable(x, NO_DEFAULT) -# Examples +struct Constant{T,Tdefault} <: Argument{T,Tdefault} + value::T + default::Tdefault +end -```julia -julia> Model(f, (x = 1.0, y = 2.0)) -Model{typeof(f),(:x, :y),(),(),Tuple{Float64,Float64},Tuple{}}(f, (x = 1.0, y = 2.0), NamedTuple()) +Constant(x) = Constant(x, NO_DEFAULT) -julia> Model(f, (x = 1.0, y = 2.0), (x = 42,)) -Model{typeof(f),(:x, :y),(:x,),(),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) -julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition of missings -Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) -``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractProbabilisticProgram - name::Symbol - f::F - args::NamedTuple{argnames,Targs} - defaults::NamedTuple{defaultnames,Tdefaults} - - """ - Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple) - - Create a model of name `name` with evaluation function `f` and missing arguments - overwritten by `missings`. - """ - function Model{missings}( - name::Symbol, - f::F, - args::NamedTuple{argnames,Targs}, - defaults::NamedTuple{defaultnames,Tdefaults}, - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults}( - name, f, args, defaults - ) + struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram + name::Symbol + evaluator::F + arguments::NamedTuple{argumentnames,Targs} end -end +A `Model` struct with model evaluation function of type `F`, and arguments `arguments`. """ - Model(name::Symbol, f, args::NamedTuple[, defaults::NamedTuple = ()]) +struct Model{F, argumentnames, Targs, internalnames, Tinternals} <: AbstractProbabilisticProgram + name::Symbol + # code::Expr + evaluator::F + arguments::NamedTuple{argumentnames,Targs} + internal_variables::NamedTuple{internalnames,Tinternals} +end -Create a model of name `name` with evaluation function `f` and missing arguments deduced -from `args`. +function Base.show(io::IO, ::MIME"text/plain", model::Model) + println(io, "Model ", model.name, " given") + print(io, " constants: ") + join(io, getconstants(model), ", ") + println(io) + print(io, " observed variables: ") + join(io, getobservedvariables(model), ", ") + println(io) + print(io, " latent variables: ") + join(io, getlatentvariables(model), ", ") + return nothing +end -Default arguments `defaults` are used internally when constructing instances of the same -model with different arguments. -""" -@generated function Model( - name::Symbol, f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple=NamedTuple() -) where {F,argnames,Targs} - missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) - return :(Model{$missings}(name, f, args, defaults)) +function Base.show(io::IO, model::Model) + println(io, "$(model.name)$(getarguments(model))") + return nothing end + """ (model::Model)([rng, varinfo, sampler, context]) @@ -153,26 +137,112 @@ end Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ -@generated function _evaluate( - model::Model{_F,argnames}, varinfo, context -) where {_F,argnames} - unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, context, $(unwrap_args...))) +function _evaluate(model::Model, varinfo, context) + matched_args = map(arg -> matchingvalue(context, varinfo, arg), getarguments(model)) + return model.evaluator(model, varinfo, context, matched_args...) end + """ - getargnames(model::Model) + getargumentnames(model::Model, [::Type{T}]) -Get a tuple of the argument names of the `model`. +Return a tuple of the argument names of the `model`. The second argument can be used to filter +the types of arguments (constant, variable, default) by passing an `Argument` subtype. """ -getargnames(model::Model{_F,argnames}) where {argnames,_F} = argnames +getargumentnames(model::Model{<:Any,argnames}) where {argnames} = argnames +@generated function getargumentnames( + model::Model{<:Any,argnames,Targs}, + ::Type{T} +) where {argnames,Targs,T} + return _getargumentnames(argnames, Targs, T) +end +function _getargumentnames(argnames, Targs, ::Type{T}) where {T} + return Tuple([n for (n, Targ) in zip(argnames, Targs.parameters) if Targ <: T]) +end + +Base.@deprecate getargnames(model) getargumentnames(model) + +""" + getarguments(model::Model, [::Type{T}]) + +Return a `NamedTuple` of the constants passed to `model`. The second argument can be used to filter +the types of arguments (constant, variable, default) by passing an `Argument` subtype. +""" +getarguments(model::Model) = map(arg -> arg.value, model.arguments) +@generated function getarguments( + model::Model{<:Any,argnames,Targs}, + ::Type{T} +) where {argnames,Targs,T} + filtered_argnames = _getargumentnames(argnames, Targs, T) + values = [:(model.arguments.$arg.value) for arg in filtered_argnames] + return :(NamedTuple{$filtered_argnames}(($(values...),))) +end + +hasargument(model::Model, argname::Symbol) = isdefined(model.arguments, argname) +getargument(model::Model, argname::Symbol) = getproperty(model.arguments, argname).value + +hasdefault(model::Model, argname::Symbol) = getdefault(model, argname) !== NO_DEFAULT +getdefault(model::Model, argname::Symbol) = getproperty(model.arguments, argname).default + """ - getmissings(model::Model) + isobservation(vn, model) -Get a tuple of the names of the missing arguments of the `model`. +Check whether the value of the expression `vn` is a real observation in the `model`. + +A variable is an observation if it is among the arguments data of the model, and the corresponding +observation value is not `missing` (e.g., it could happen that the arguments contain `x = +[missing, 42]` -- then `x[1]` is not an observation, but `x[2]` is.) """ -getmissings(model::Model{_F,_a,_d,missings}) where {missings,_F,_a,_d} = missings +function isobservation(vn::VarName{s}, model::Model{<:Any,argnames}) where {s,argnames} + return (s in argnames) && isobservation(vn, getproperty(model.arguments, s)) +end +isobservation(::VarName, ::Constant) = false +isobservation(vn::VarName, obs::Variable{Missing}) = false +isobservation(vn::VarName, obs::Variable) = !ismissing(_getindex(obs.value, vn.indexing)) + +function getvariables_separated(model::Model) + observed_variables = VarName[] + latent_variables = VarName[] + + # separate the argument variables into observed and latend + for (var, value) in pairs(getarguments(model, Variable)) + if value isa AbstractArray + all_indices = CartesianIndices(value) + missing_indices = filter(ix -> ismissing(value[ix]), all_indices) + if isempty(missing_indices) + # all indexed given -- full variable observed + push!(observed_variables, VarName{var}()) + else + # mixed case -- indexed variables in both categories + for ix in all_indices + complete_name = VarName{var}((Tuple(ix),)) + if ix in missing_indices + push!(latent_variables, complete_name) + else + push!(observed_variables, complete_name) + end + end + end + else + complete_name = VarName{var}() + if ismissing(value) + push!(latent_variables, complete_name) + else + push!(observed_variables, complete_name) + end + end + end + + # add purely internal variables as latent + append!(latent_variables, (VarName{var}() for var in keys(model.internal_variables))) + + return (observed_variables, latent_variables) +end + +getobservedvariables(model::Model) = getvariables_separated(model)[1] +getlatentvariables(model::Model) = getvariables_separated(model)[2] +getconstants(model::Model) = VarName[VarName{c}() for c in getargumentnames(model, Constant)] """ nameof(model::Model) diff --git a/src/prob_macro.jl b/src/prob_macro.jl index d761e9fdc..f25f19d83 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -1,3 +1,12 @@ +function _isdefault(argnames, TArgs, arg) + ix = findfirst(==(arg), argnames) + if !ismissing(ix) + return !(TArgs.parameters[ix] <: Argument{<:Any,NoDefault}) + else + return false + end +end + macro logprob_str(str) expr1, expr2 = get_exprs(str) return :(logprob($(esc(expr1)), $(esc(expr2)))) @@ -53,11 +62,11 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names else vi = nothing end - defaults = model.defaults - @assert all(getargnames(model)) do arg - isdefined(ntl, arg) || - isdefined(ntr, arg) || - isdefined(defaults, arg) && getfield(defaults, arg) !== missing + @assert all(getargumentnames(model)) do arg + isdefined(ntl, arg) || isdefined(ntr, arg) || + hasargument(model, arg) && + hasdefault(model, arg) && + !ismissing(getdefault(model, arg)) end return Val(:likelihood), model, vi else @@ -78,9 +87,8 @@ end function probtype( left::NamedTuple{leftnames}, right::NamedTuple{rightnames}, - model::Model{_F,argnames,defaultnames}, -) where {leftnames,rightnames,argnames,defaultnames,_F} - defaults = model.defaults + model::Model{_F,argnames}, +) where {leftnames,rightnames,argnames,_F} prior_rhs = all( n -> n in (:model, :varinfo) || n in argnames && getfield(right, n) !== missing, rightnames, @@ -90,8 +98,8 @@ function probtype( return getfield(left, arg) elseif arg in rightnames return getfield(right, arg) - elseif arg in defaultnames - return getfield(defaults, arg) + elseif hasargument(model, arg) && hasdefault(model, arg) + return getdefault(model, arg) else return nothing end @@ -153,33 +161,31 @@ end @generated function make_prior_model( left::NamedTuple{leftnames}, right::NamedTuple{rightnames}, - model::Model{_F,argnames,defaultnames}, -) where {leftnames,rightnames,argnames,defaultnames,_F} + model::Model{_F,argnames,TArgs}, +) where {leftnames,rightnames,argnames,TArgs,_F} argvals = [] - missings = [] warnings = [] for argname in argnames if argname in leftnames - push!(argvals, :(deepcopy(left.$argname))) - push!(missings, argname) + push!(argvals, :(Constant(deepcopy(left.$argname)))) elseif argname in rightnames - push!(argvals, :(right.$argname)) - elseif argname in defaultnames - push!(argvals, :(model.defaults.$argname)) + push!(argvals, :(Variable(right.$argname))) + elseif _isdefault(argnames, TArgs, argname) + push!(argvals, :(Variable(model.arguments.$argname.default))) else push!(warnings, :(@warn($(warn_msg(argname))))) - push!(argvals, :(nothing)) + push!(argvals, :(Variable(nothing))) end end # `args` is inserted as properly typed NamedTuple expression; - # `missings` is splatted into a tuple at compile time and inserted as literal return quote $(warnings...) - Model{$(Tuple(missings))}( - model.name, model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults - ) + Model(model.name, + model.evaluator, + $(to_namedtuple_expr(argnames, argvals)), + model.internal_variables) end end @@ -213,19 +219,17 @@ end @generated function make_likelihood_model( left::NamedTuple{leftnames}, right::NamedTuple{rightnames}, - model::Model{_F,argnames,defaultnames}, -) where {leftnames,rightnames,argnames,defaultnames,_F} + model::Model{_F,argnames,TArgs}, +) where {leftnames,rightnames,argnames,TArgs,_F} argvals = [] - missings = [] - + for argname in argnames if argname in leftnames - push!(argvals, :(left.$argname)) + push!(argvals, :(Variable(left.$argname))) elseif argname in rightnames - push!(argvals, :(right.$argname)) - push!(missings, argname) - elseif argname in defaultnames - push!(argvals, :(model.defaults.$argname)) + push!(argvals, :(Constant(right.$argname))) + elseif _isdefault(argnames, TArgs, argname) + push!(argvals, :(Variable(model.arguments.$argname.default))) else throw( "This point should not be reached. Please open an issue in the DynamicPPL.jl repository.", @@ -235,7 +239,9 @@ end # `args` is inserted as properly typed NamedTuple expression; # `missings` is splatted into a tuple at compile time and inserted as literal - return :(Model{$(Tuple(missings))}( - model.name, model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults - )) + return :(Model( + model.name, + model.evaluator, + $(to_namedtuple_expr(argnames, argvals)), + model.internal_variables)) end diff --git a/src/utils.jl b/src/utils.jl index e77a4ecdd..da66c7ddd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -169,3 +169,6 @@ end ####################### collectmaybe(x) = x collectmaybe(x::Base.AbstractSet) = collect(x) + +_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) +_getindex(x, inds::Tuple{}) = x diff --git a/src/varname.jl b/src/varname.jl index 343bb0da8..59e051ad8 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -15,30 +15,3 @@ function subsumes_string(u::String, v::String, u_indexing=u * "[") return u == v || startswith(v, u_indexing) end -""" - inargnames(varname::VarName, model::Model) - -Statically check whether the variable of name `varname` is an argument of the `model`. - -Possibly existing indices of `varname` are neglected. -""" -@generated function inargnames(::VarName{s}, ::Model{_F,argnames}) where {s,argnames,_F} - return s in argnames -end - -""" - inmissings(varname::VarName, model::Model) - -Statically check whether the variable of name `varname` is a statically declared unobserved variable -of the `model`. - -Possibly existing indices of `varname` are neglected. -""" -@generated function inmissings( - ::VarName{s}, ::Model{_F,_a,_T,missings} -) where {s,missings,_F,_a,_T} - return s in missings -end - -# HACK: Type-piracy. Is this really the way to go? -AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 74fb88d70..c4c1d7403 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -110,7 +110,8 @@ const gdemo_models = ( continue end - loglikelihood = if length(keys(lls)) == 1 && length(m.args.x) == 1 + args = DynamicPPL.getarguments(m) + loglikelihood = if length(keys(lls)) == 1 && length(args.x) == 1 # Only have one observation, so we need to double it # for comparison with other models. 2 * sum(lls[first(keys(lls))])