diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..8c1b58113 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,44 +1,12 @@ +const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " * + "Distributions." + const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__) const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng) -""" - isassumption(expr) - -Return an expression that can be evaluated to check if `expr` is an assumption in the -model. - -Let `expr` be `:(x[1])`. It is an assumption in the following cases: - 1. `x` is not among the input data to the model, - 2. `x` is among the input data to the model but with a value `missing`, or - 3. `x` is among the input data to the model with a value other than missing, - but `x[1] === missing`. - -When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. -""" -function isassumption(expr::Union{Symbol,Expr}) - vn = gensym(:vn) - - return quote - let $vn = $(varname(expr)) - # This branch should compile nicely in all cases except for partial missing data - # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) - true - else - # Evaluate the LHS - $expr === missing - end - end - end -end - -# failsafe: a literal is never an assumption -isassumption(expr) = :(false) """ isliteral(expr) - Return `true` if `expr` is a literal, e.g. `1.0` or `[1.0, ]`, and `false` otherwise. """ isliteral(e) = false @@ -47,7 +15,6 @@ isliteral(e::Expr) = !isempty(e.args) && all(isliteral, e.args) """ check_tilde_rhs(x) - Check if the right-hand side `x` of a `~` is a `Distribution` or an array of `Distributions`, then return `x`. """ @@ -63,10 +30,8 @@ check_tilde_rhs(x::AbstractArray{<:Distribution}) = x """ unwrap_right_vn(right, vn) - Return the unwrapped distribution on the right-hand side and variable name on the left-hand side of a `~` expression such as `x ~ Normal()`. - This is used mainly to unwrap `NamedDist` distributions. """ unwrap_right_vn(right, vn) = right, vn @@ -74,10 +39,8 @@ unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name) """ unwrap_right_left_vns(right, left, vns) - Return the unwrapped distributions on the right-hand side and values and variable names on the left-hand side of a `.~` expression such as `x .~ Normal()`. - This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the variables. """ @@ -104,6 +67,7 @@ function unwrap_right_left_vns( return unwrap_right_left_vns(right, left, vns) end + ################# # Main Compiler # ################# @@ -137,9 +101,13 @@ end function model(mod, linenumbernode, expr, warn) modelinfo = build_model_info(expr) - # Generate main body - modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) + # Generate main body and find all variable symbols + modelinfo[:body], modelinfo[:varnames] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) + # extract parameters and observations from that + modelinfo[:paramnames] = filter(x -> x ∉ modelinfo[:varnames], modelinfo[:allargs_syms]) + modelinfo[:obsnames] = setdiff(modelinfo[:allargs_syms], modelinfo[:paramnames]) + return build_output(modelinfo, linenumbernode) end @@ -167,8 +135,8 @@ function build_model_info(input_expr) modelinfo = Dict( :allargs_exprs => [], :allargs_syms => [], - :allargs_namedtuple => NamedTuple(), - :defaults_namedtuple => NamedTuple(), + # :allargs_namedtuple => NamedTuple(), + # :defaults_namedtuple => NamedTuple(), :modeldef => modeldef, ) return modelinfo @@ -198,26 +166,26 @@ function build_model_info(input_expr) end # Build named tuple expression of the argument symbols and variables of the same name. - allargs_namedtuple = to_namedtuple_expr(allargs_syms) - - # Extract default values of the positional and keyword arguments. - default_syms = [] - default_vals = [] - for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults) - if val !== NO_DEFAULT - push!(default_syms, sym) - push!(default_vals, val) - end - end + # allargs_namedtuple = to_namedtuple_expr(allargs_syms) + + # # Extract default values of the positional and keyword arguments. + # default_syms = [] + # default_vals = [] + # for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults) + # if val !== NO_DEFAULT + # push!(default_syms, sym) + # push!(default_vals, val) + # end + # end - # Build named tuple expression of the argument symbols with default values. - defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals) + # # Build named tuple expression of the argument symbols with default values. + # defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals) modelinfo = Dict( :allargs_exprs => allargs_exprs, :allargs_syms => allargs_syms, - :allargs_namedtuple => allargs_namedtuple, - :defaults_namedtuple => defaults_namedtuple, + # :allargs_namedtuple => allargs_namedtuple, + # :defaults_namedtuple => defaults_namedtuple, :modeldef => modeldef, ) @@ -233,43 +201,54 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. """ -generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn) +function generate_mainbody(mod, expr, warn) + varnames = Symbol[] + body = generate_mainbody!(mod, Symbol[], varnames, expr, warn) + return body, varnames +end -generate_mainbody!(mod, found, x, warn) = x -function generate_mainbody!(mod, found, sym::Symbol, warn) +generate_mainbody!(mod, found_internals, varnames, x, warn) = x +function generate_mainbody!(mod, found_internals, sym::Symbol, warn) if sym in DEPRECATED_INTERNALNAMES newsym = Symbol(:_, sym, :__) Base.depwarn( "internal variable `$sym` is deprecated, use `$newsym` instead.", :generate_mainbody!, ) - return generate_mainbody!(mod, found, newsym, warn) + return generate_mainbody!(mod, found_internals, newsym, warn) end - if warn && sym in INTERNALNAMES && sym ∉ found + if warn && sym in INTERNALNAMES && sym ∉ found_internals @warn "you are using the internal variable `$sym`" - push!(found, sym) + push!(found_internals, sym) end return sym end -function generate_mainbody!(mod, found, expr::Expr, warn) +function generate_mainbody!(mod, found_internals, varnames, expr::Expr, warn) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) + return generate_mainbody!( + mod, + found_internals, + varnames, + macroexpand(mod, expr; recursive=true), + warn + ) end # Modify dotted tilde operators. args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing L, R = args_dottilde + push!(varnames, vsym(L)) return Base.remove_linenums!( generate_dot_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found_internals, varnames, L, warn), + generate_mainbody!(mod, found_internals, varnames, R, warn), ), ) end @@ -278,15 +257,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn) args_tilde = getargs_tilde(expr) if args_tilde !== nothing L, R = args_tilde + push!(varnames, vsym(L)) return Base.remove_linenums!( generate_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found_internals, varnames, L, warn), + generate_mainbody!(mod, found_internals, varnames, R, warn), ), ) end - return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) + return Expr( + expr.head, + map(x -> generate_mainbody!(mod, found_internals, varnames, x, warn), expr.args)... + ) end """ @@ -307,26 +290,26 @@ function generate_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn inds isobservation return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) - if $isassumption - $left = $(DynamicPPL.tilde_assume!)( + $isobservation = $(DynamicPPL.isobservation)($vn, __model__) + if $isobservation + $(DynamicPPL.tilde_observe!)( __context__, - $(DynamicPPL.unwrap_right_vn)( - $(DynamicPPL.check_tilde_rhs)($right), $vn - )..., + $(DynamicPPL.check_tilde_rhs)($right), + $left, + $vn, $inds, __varinfo__, ) else - $(DynamicPPL.tilde_observe!)( + $left = $(DynamicPPL.tilde_assume!)( __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - $vn, + $(DynamicPPL.unwrap_right_vn)( + $(DynamicPPL.check_tilde_rhs)($right), $vn + )..., $inds, __varinfo__, ) @@ -351,26 +334,26 @@ function generate_dot_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn inds isobservation return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) - if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume!)( + $isobservation = $(DynamicPPL.isobservation)($vn, __model__) + if $isobservation + $(DynamicPPL.dot_tilde_observe!)( __context__, - $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $left, $vn - )..., + $(DynamicPPL.check_tilde_rhs)($right), + $left, + $vn, $inds, __varinfo__, ) else - $(DynamicPPL.dot_tilde_observe!)( + $left .= $(DynamicPPL.dot_tilde_assume!)( __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - $vn, + $(DynamicPPL.unwrap_right_left_vns)( + $(DynamicPPL.check_tilde_rhs)($right), $left, $vn + )..., $inds, __varinfo__, ) @@ -413,26 +396,33 @@ function build_output(modelinfo, linenumbernode) ## Build the model function. - # Extract the named tuple expression of all arguments and the default values. - allargs_namedtuple = modelinfo[:allargs_namedtuple] - defaults_namedtuple = modelinfo[:defaults_namedtuple] + # Extract the named tuple expression of all arguments + allargs_newnames = [gensym(x) for x in modelinfo[:allargs_syms]] + allargs_wrapped = [ + x ∈ modelinfo[:obsnames] ? :($(DynamicPPL.Observation)($x)) : :($(DynamicPPL.Parameter)($x)) + for x in modelinfo[:allargs_syms] + ] + allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)] + allargs_namedtuple = to_namedtuple_expr(modelinfo[:allargs_syms], allargs_newnames) + # modelinfo[:allargs_namedtuple] = to_namedtuple_expr(modelinfo[:allargs_syms], args_vals) # Update the function body of the user-specified model. # We use a name for the anonymous evaluator that does not conflict with other variables. modeldef = modelinfo[:modeldef] + @gensym evaluator # We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure # that no new `LineNumberNode`s are added apart from the reference `linenumbernode` # to the call site modeldef[:body] = MacroTools.@q begin $(linenumbernode) + # $(observation_checks...) $evaluator = $(MacroTools.combinedef(evaluatordef)) + $(allargs_decls...) return $(DynamicPPL.Model)( $(QuoteNode(modeldef[:name])), $evaluator, - $allargs_namedtuple, - $defaults_namedtuple, - ) + $allargs_namedtuple) end return :($(Base).@__doc__ $(MacroTools.combinedef(modeldef))) @@ -442,7 +432,7 @@ function warn_empty(body) if all(l -> isa(l, LineNumberNode), body.args) @warn("Model definition seems empty, still continue.") end - return nothing + return end """ diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..0ab8b4616 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,8 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) -_getindex(x, inds::Tuple{}) = x # assume """ diff --git a/src/model.jl b/src/model.jl index 9ec047a44..472e37e5e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,60 +1,23 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} + struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram name::Symbol - f::F - args::NamedTuple{argnames,Targs} - defaults::NamedTuple{defaultnames,Tdefaults} + evaluator::F + arguments::NamedTuple{argumentnames,Targs} end -A `Model` struct with model evaluation function of type `F`, arguments of names `argnames` -types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, and missing -arguments `missings`. - -Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`. - -An argument with a type of `Missing` will be in `missings` by default. However, in -non-traditional use-cases `missings` can be defined differently. All variables in `missings` -are treated as random variables rather than observations. - -The default arguments are used internally when constructing instances of the same model with -different arguments. +A `Model` struct with model evaluation function of type `F`, and arguments `arguments`. # Examples ```julia -julia> Model(f, (x = 1.0, y = 2.0)) -Model{typeof(f),(:x, :y),(),(),Tuple{Float64,Float64},Tuple{}}(f, (x = 1.0, y = 2.0), NamedTuple()) - -julia> Model(f, (x = 1.0, y = 2.0), (x = 42,)) -Model{typeof(f),(:x, :y),(:x,),(),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) - -julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition of missings -Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) +TODO ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractProbabilisticProgram +struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram name::Symbol - f::F - args::NamedTuple{argnames,Targs} - defaults::NamedTuple{defaultnames,Tdefaults} - - """ - Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple) - - Create a model of name `name` with evaluation function `f` and missing arguments - overwritten by `missings`. - """ - function Model{missings}( - name::Symbol, - f::F, - args::NamedTuple{argnames,Targs}, - defaults::NamedTuple{defaultnames,Tdefaults}, - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults}( - name, f, args, defaults - ) - end + # code::Expr + evaluator::F + arguments::NamedTuple{argumentnames,Targs} end """ @@ -66,13 +29,108 @@ from `args`. Default arguments `defaults` are used internally when constructing instances of the same model with different arguments. """ -@generated function Model( - name::Symbol, f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple=NamedTuple() -) where {F,argnames,Targs} - missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) - return :(Model{$missings}(name, f, args, defaults)) +# @generated function Model( +# name::Symbol, +# f::F, +# args::NamedTuple{argnames,Targs} +# ) where {F,argnames,Targs} +# # missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) +# return :(Modelngs}(name, f, args, defaults)) +# end + + +abstract type Argument{T} end +struct Observation{T} <: Argument{T} + value::T +end +struct Constant{T} <: Argument{T} + value::T end +""" + isobservation(vn, model) + +Check whether the value of the expression `vn` is a real observation in the `model`. + +A variable is an observations if it is among the observation data of the model, an the corresponding +observation value is not `missing` (e.g., it could happen that the observation data contain `x = +[missing, 42]` -- then `x[1]` is not an observation, but `x[2]` is.) +""" +@generated function isobservation( + vn::VarName{s}, + model::Model{_F, argnames} +) where {s, _F, argnames} + if s in argnames + return :(isobservation(vn, getproperty(model.arguments, $(Meta.quot(s))))) + else + return :(false) + end +end +isobservation(::VarName, ::Parameter) = false +isobservation(vn::VarName, obs::Observation) = !ismissing(_getindex(obs, vn.indexing)) +isobservation(vn::VarName, obs::Observation{Missing}) = false + + +# """ +# @ConditionedModel{; obs1::Type1, obs2::Type2, ...} +# @ConditionedModel{f, parameternames, Tparams; obs1::Type1, obs2::Type2, ...} + +# Macro with more convenient syntax for declaring `Model` types with observations (similar to the +# `Base.@NamedTuple` macro). The observations to the parameters part of the braces: +# `@ConditionedModel{; x::Int, y}`. Type annotations can be omitted, in which case the type is +# defaulted to `Any`. + +# The non-parameters part can be used to match the other type arguments of `Model`: the evaluator +# function type `F`, and the `parameternames` and their type tuple `Tparams`. +# """ +# macro ConditionedModel(ex) +# # Code adapted from Base.@NamedTuple macro; parameter lists in `:braces` expressions do work: +# # julia> :(@bla{f; x, y}).args +# # 3-element Array{Any,1}: +# # Symbol("@bla") +# # :(#= REPL[55]:1 =#) +# # :({$(Expr(:parameters, :x, :y)), f}) + +# Meta.isexpr(ex, :braces) || throw(ArgumentError("@ConditionedModel expects {;...}")) +# decls = filter(e -> !(e isa LineNumberNode), ex.args) +# Meta.isexpr(decls[1], :parameters) || throw(ArgumentError("@ConditionedModel expects {;...}")) +# cond_part = decls[1].args +# types_part = decls[2:end] +# all(e -> e isa Symbol || Meta.isexpr(e, :(::)), cond_part) || +# throw(ArgumentError("@ConditionedModel must contain a sequence of name or name::type expressions")) +# obsvars = [QuoteNode(e isa Symbol ? e : e.args[1]) for e in cond_part] +# obstypes = [esc(e isa Symbol ? :Any : e.args[2]) for e in cond_part] +# _f = esc(get(types_part, 1, :Any)) +# _parameternames = esc(get(types_part, 2, :Any)) +# _tparams = esc(get(types_part, 3, :Any)) + +# return :($(DynamicPPL.Model){ +# $_f, +# $_parameternames, +# ($(obsvars...),), +# $_tparams, +# Tuple{$(obstypes...)} +# }) +# end + +# """ +# GenerativeModel{F, parameters, TParams} + +# Type alias for models without observations. +# """ +# const GenerativeModel{F, parameternames, Tparams} = @ConditionedModel{F, parameternames, Tparams;} + +function Base.show(io::IO, model::Model) + println(io, "Model ", model.name, " given") + print(io, " parameters ") + join(io, getparameternames(model), ", ") + println() + print(io, " observations ") + join(io, getobservationnames(model), ", ") + # println(_pretty(model.code)) +end + + """ (model::Model)([rng, varinfo, sampler, context]) @@ -149,9 +207,9 @@ function evaluate_threadsafe(model, varinfo, context) end """ - _evaluate(model::Model, varinfo, context) + _evaluate(rng, model::Model, varinfo, sampler, context) -Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. +Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. """ @generated function _evaluate( model::Model{_F,argnames}, varinfo, context @@ -161,18 +219,24 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf end """ - getargnames(model::Model) + getparameternames(model::Model) Get a tuple of the argument names of the `model`. """ -getargnames(model::Model{_F,argnames}) where {argnames,_F} = argnames +@generated function getparameternames(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} + param_indices = filter(i -> Targs.parameters[i] <: Parameter, eachindex(Targs.parameters)) + return argnames[param_indices] +end """ - getmissings(model::Model) + getparameternames(model::Model) -Get a tuple of the names of the missing arguments of the `model`. +Get a tuple of the observation names of the `model`. """ -getmissings(model::Model{_F,_a,_d,missings}) where {missings,_F,_a,_d} = missings +@generated function getobservationnames(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} + obs_indices = filter(i -> Targs.parameters[i] <: Observation, eachindex(Targs.parameters)) + return argnames[obs_indices] +end """ nameof(model::Model) @@ -182,40 +246,50 @@ Get the name of the `model` as `Symbol`. Base.nameof(model::Model) = model.name """ - logjoint(model::Model, varinfo::AbstractVarInfo) + logdensity(model::Model, varinfo::AbstractVarInfo) -Return the log joint probability of variables `varinfo` for the probabilistic `model`. +Return the log joint probability of variables in `varinfo` for the probabilistic `model`. -See [`logjoint`](@ref) and [`loglikelihood`](@ref). +See [`logprior`](@ref) and [`loglikelihood`](@ref). """ -function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, DefaultContext()) +function AbstractPPL.logdensity(model::Model, varinfo::AbstractVarInfo) + model(varinfo, SampleFromPrior(), DefaultContext()) return getlogp(varinfo) end -""" - logprior(model::Model, varinfo::AbstractVarInfo) - -Return the log prior probability of variables `varinfo` for the probabilistic `model`. +function AbstractPPL.decondition(model::Model, name = Symbol(model.name, "_joint")) + return Model(name, model.evaluator, model.parameters, NamedTuple()) +end -See also [`logjoint`](@ref) and [`loglikelihood`](@ref). -""" -function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, PriorContext()) - return getlogp(varinfo) +function AbstractPPL.condition(model::Model, observations, name = Symbol(model.name, "_cond")) + return Model(name, model.evaluator, model.parameters, merge(model.observations, observations)) end -""" - loglikelihood(model::Model, varinfo::AbstractVarInfo) -Return the log likelihood of variables `varinfo` for the probabilistic `model`. +# """ +# logprior(model::Model, varinfo::AbstractVarInfo) + +# Return the log prior probability of variables `varinfo` for the probabilistic `model`. + +# See also [`logjoint`](@ref) and [`loglikelihood`](@ref). +# """ +# function logprior(model::Model, varinfo::AbstractVarInfo) +# model(varinfo, SampleFromPrior(), PriorContext()) +# return getlogp(varinfo) +# end + +# """ +# loglikelihood(model::Model, varinfo::AbstractVarInfo) + +# Return the log likelihood of variables `varinfo` for the probabilistic `model`. + +# See also [`logjoint`](@ref) and [`logprior`](@ref). +# """ +# function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) +# model(varinfo, SampleFromPrior(), LikelihoodContext()) +# return getlogp(varinfo) +# # end -See also [`logjoint`](@ref) and [`logprior`](@ref). -""" -function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, LikelihoodContext()) - return getlogp(varinfo) -end """ generated_quantities(model::Model, chain::AbstractChains) diff --git a/src/utils.jl b/src/utils.jl index e77a4ecdd..da66c7ddd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -169,3 +169,6 @@ end ####################### collectmaybe(x) = x collectmaybe(x::Base.AbstractSet) = collect(x) + +_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) +_getindex(x, inds::Tuple{}) = x diff --git a/src/varname.jl b/src/varname.jl index 343bb0da8..c17441b3d 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -15,30 +15,5 @@ function subsumes_string(u::String, v::String, u_indexing=u * "[") return u == v || startswith(v, u_indexing) end -""" - inargnames(varname::VarName, model::Model) - -Statically check whether the variable of name `varname` is an argument of the `model`. - -Possibly existing indices of `varname` are neglected. -""" -@generated function inargnames(::VarName{s}, ::Model{_F,argnames}) where {s,argnames,_F} - return s in argnames -end - -""" - inmissings(varname::VarName, model::Model) - -Statically check whether the variable of name `varname` is a statically declared unobserved variable -of the `model`. - -Possibly existing indices of `varname` are neglected. -""" -@generated function inmissings( - ::VarName{s}, ::Model{_F,_a,_T,missings} -) where {s,missings,_F,_a,_T} - return s in missings -end - # HACK: Type-piracy. Is this really the way to go? AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym