From 10a15054246648416bbf92e47a013b59125562c5 Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Mon, 21 Jun 2021 13:35:26 +0200 Subject: [PATCH 01/13] Different approach to how observations/missings are stored in the model --- src/compiler.jl | 162 +++++++++++++-------------------- src/context_implementations.jl | 2 - src/model.jl | 125 ++++++++++++++----------- src/utils.jl | 3 + src/varname.jl | 27 ------ 5 files changed, 138 insertions(+), 181 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..7253c44dc 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,40 +1,6 @@ const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__) const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng) -""" - isassumption(expr) - -Return an expression that can be evaluated to check if `expr` is an assumption in the -model. - -Let `expr` be `:(x[1])`. It is an assumption in the following cases: - 1. `x` is not among the input data to the model, - 2. `x` is among the input data to the model but with a value `missing`, or - 3. `x` is among the input data to the model with a value other than missing, - but `x[1] === missing`. - -When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. -""" -function isassumption(expr::Union{Symbol,Expr}) - vn = gensym(:vn) - - return quote - let $vn = $(varname(expr)) - # This branch should compile nicely in all cases except for partial missing data - # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) - true - else - # Evaluate the LHS - $expr === missing - end - end - end -end - -# failsafe: a literal is never an assumption -isassumption(expr) = :(false) """ isliteral(expr) @@ -137,8 +103,14 @@ end function model(mod, linenumbernode, expr, warn) modelinfo = build_model_info(expr) - # Generate main body - modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) + # Generate main body and find all variable symbols + modelinfo[:body], modelinfo[:varnames] = generate_mainbody( + mod, modelinfo[:modeldef][:body], warn + ) + + # extract parameters and observations from that + modelinfo[:paramnames] = filter(x -> x ∉ modelinfo[:varnames], modelinfo[:allargs_syms]) + modelinfo[:obsnames] = setdiff(modelinfo[:allargs_syms], modelinfo[:paramnames]) return build_output(modelinfo, linenumbernode) end @@ -167,8 +139,6 @@ function build_model_info(input_expr) modelinfo = Dict( :allargs_exprs => [], :allargs_syms => [], - :allargs_namedtuple => NamedTuple(), - :defaults_namedtuple => NamedTuple(), :modeldef => modeldef, ) return modelinfo @@ -196,28 +166,10 @@ function build_model_info(input_expr) x_ => x end end - - # Build named tuple expression of the argument symbols and variables of the same name. - allargs_namedtuple = to_namedtuple_expr(allargs_syms) - - # Extract default values of the positional and keyword arguments. - default_syms = [] - default_vals = [] - for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults) - if val !== NO_DEFAULT - push!(default_syms, sym) - push!(default_vals, val) - end - end - - # Build named tuple expression of the argument symbols with default values. - defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals) - + modelinfo = Dict( :allargs_exprs => allargs_exprs, :allargs_syms => allargs_syms, - :allargs_namedtuple => allargs_namedtuple, - :defaults_namedtuple => defaults_namedtuple, :modeldef => modeldef, ) @@ -233,43 +185,50 @@ Generate the body of the main evaluation function from expression `expr` and arg If `warn` is true, a warning is displayed if internal variables are used in the model definition. """ -generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn) +function generate_mainbody(mod, expr, warn) + varnames = Symbol[] + body = generate_mainbody!(mod, Symbol[], varnames, expr, warn) + return body, varnames +end -generate_mainbody!(mod, found, x, warn) = x -function generate_mainbody!(mod, found, sym::Symbol, warn) +generate_mainbody!(mod, found_internals, varnames, x, warn) = x +function generate_mainbody!(mod, found_internals, sym::Symbol, warn) if sym in DEPRECATED_INTERNALNAMES newsym = Symbol(:_, sym, :__) Base.depwarn( "internal variable `$sym` is deprecated, use `$newsym` instead.", :generate_mainbody!, ) - return generate_mainbody!(mod, found, newsym, warn) + return generate_mainbody!(mod, found_internals, newsym, warn) end - if warn && sym in INTERNALNAMES && sym ∉ found + if warn && sym in INTERNALNAMES && sym ∉ found_internals @warn "you are using the internal variable `$sym`" - push!(found, sym) + push!(found_internals, sym) end return sym end -function generate_mainbody!(mod, found, expr::Expr, warn) +function generate_mainbody!(mod, found_internals, varnames, expr::Expr, warn) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) - return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) + return generate_mainbody!( + mod, found_internals, varnames, macroexpand(mod, expr; recursive=true), warn + ) end # Modify dotted tilde operators. args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing L, R = args_dottilde + push!(varnames, vsym(L)) return Base.remove_linenums!( generate_dot_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found_internals, varnames, L, warn), + generate_mainbody!(mod, found_internals, varnames, R, warn), ), ) end @@ -278,15 +237,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn) args_tilde = getargs_tilde(expr) if args_tilde !== nothing L, R = args_tilde + push!(varnames, vsym(L)) return Base.remove_linenums!( generate_tilde( - generate_mainbody!(mod, found, L, warn), - generate_mainbody!(mod, found, R, warn), + generate_mainbody!(mod, found_internals, varnames, L, warn), + generate_mainbody!(mod, found_internals, varnames, R, warn), ), ) end - return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) + return Expr( + expr.head, + map(x -> generate_mainbody!(mod, found_internals, varnames, x, warn), expr.args)..., + ) end """ @@ -307,26 +270,26 @@ function generate_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn inds isobservation return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) - if $isassumption - $left = $(DynamicPPL.tilde_assume!)( + $isobservation = $(DynamicPPL.isobservation)($vn, __model__) + if $isobservation + $(DynamicPPL.tilde_observe!)( __context__, - $(DynamicPPL.unwrap_right_vn)( - $(DynamicPPL.check_tilde_rhs)($right), $vn - )..., + $(DynamicPPL.check_tilde_rhs)($right), + $left, + $vn, $inds, __varinfo__, ) else - $(DynamicPPL.tilde_observe!)( + $left = $(DynamicPPL.tilde_assume!)( __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - $vn, + $(DynamicPPL.unwrap_right_vn)( + $(DynamicPPL.check_tilde_rhs)($right), $vn + )..., $inds, __varinfo__, ) @@ -351,26 +314,26 @@ function generate_dot_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn inds isobservation return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) - if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume!)( + $isobservation = $(DynamicPPL.isobservation)($vn, __model__) + if $isobservation + $(DynamicPPL.dot_tilde_observe!)( __context__, - $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $left, $vn - )..., + $(DynamicPPL.check_tilde_rhs)($right), + $left, + $vn, $inds, __varinfo__, ) else - $(DynamicPPL.dot_tilde_observe!)( + $left .= $(DynamicPPL.dot_tilde_assume!)( __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - $vn, + $(DynamicPPL.unwrap_right_left_vns)( + $(DynamicPPL.check_tilde_rhs)($right), $left, $vn + )..., $inds, __varinfo__, ) @@ -413,10 +376,15 @@ function build_output(modelinfo, linenumbernode) ## Build the model function. - # Extract the named tuple expression of all arguments and the default values. - allargs_namedtuple = modelinfo[:allargs_namedtuple] - defaults_namedtuple = modelinfo[:defaults_namedtuple] - + # Extract the named tuple expression of all arguments + allargs_newnames = [gensym(x) for x in modelinfo[:allargs_syms]] + allargs_wrapped = [ + x ∈ modelinfo[:obsnames] ? :($(DynamicPPL.Observation)($x)) : :($(DynamicPPL.Constant)($x)) + for x in modelinfo[:allargs_syms] + ] + allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)] + allargs_namedtuple = to_namedtuple_expr(modelinfo[:allargs_syms], allargs_newnames) + # Update the function body of the user-specified model. # We use a name for the anonymous evaluator that does not conflict with other variables. modeldef = modelinfo[:modeldef] @@ -427,11 +395,11 @@ function build_output(modelinfo, linenumbernode) modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) + $(allargs_decls...) return $(DynamicPPL.Model)( $(QuoteNode(modeldef[:name])), $evaluator, $allargs_namedtuple, - $defaults_namedtuple, ) end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..0ab8b4616 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,8 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) -_getindex(x, inds::Tuple{}) = x # assume """ diff --git a/src/model.jl b/src/model.jl index 9ec047a44..fcf079898 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,23 +1,10 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} + struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram name::Symbol - f::F - args::NamedTuple{argnames,Targs} - defaults::NamedTuple{defaultnames,Tdefaults} + evaluator::F + arguments::NamedTuple{argumentnames,Targs} end - -A `Model` struct with model evaluation function of type `F`, arguments of names `argnames` -types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, and missing -arguments `missings`. - -Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`. - -An argument with a type of `Missing` will be in `missings` by default. However, in -non-traditional use-cases `missings` can be defined differently. All variables in `missings` -are treated as random variables rather than observations. - -The default arguments are used internally when constructing instances of the same model with -different arguments. +A `Model` struct with model evaluation function of type `F`, and arguments `arguments`. # Examples @@ -32,47 +19,55 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractProbabilisticProgram +struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram name::Symbol - f::F - args::NamedTuple{argnames,Targs} - defaults::NamedTuple{defaultnames,Tdefaults} - - """ - Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple) - - Create a model of name `name` with evaluation function `f` and missing arguments - overwritten by `missings`. - """ - function Model{missings}( - name::Symbol, - f::F, - args::NamedTuple{argnames,Targs}, - defaults::NamedTuple{defaultnames,Tdefaults}, - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults}( - name, f, args, defaults - ) - end + # code::Expr + evaluator::F + arguments::NamedTuple{argumentnames,Targs} +end + + +abstract type Argument{T} end +struct Observation{T} <: Argument{T} + value::T +end +struct Constant{T} <: Argument{T} + value::T end """ - Model(name::Symbol, f, args::NamedTuple[, defaults::NamedTuple = ()]) + isobservation(vn, model) -Create a model of name `name` with evaluation function `f` and missing arguments deduced -from `args`. +Check whether the value of the expression `vn` is a real observation in the `model`. -Default arguments `defaults` are used internally when constructing instances of the same -model with different arguments. +A variable is an observation if it is among the arguments data of the model, and the corresponding +observation value is not `missing` (e.g., it could happen that the arguments contain `x = +[missing, 42]` -- then `x[1]` is not an observation, but `x[2]` is.) """ -@generated function Model( - name::Symbol, f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple=NamedTuple() -) where {F,argnames,Targs} - missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) - return :(Model{$missings}(name, f, args, defaults)) +@generated function isobservation( + vn::VarName{s}, + model::Model{_F, argnames} +) where {s, _F, argnames} + if s in argnames + return :(isobservation(vn, model.arguments.$s)) + else + return :(false) + end +end +isobservation(::VarName, ::Constant) = false +isobservation(vn::VarName, obs::Observation) = !ismissing(_getindex(obs.value, vn.indexing)) +isobservation(vn::VarName, obs::Observation{Missing}) = false + +function Base.show(io::IO, model::Model) + println(io, "Model ", model.name, " given") + print(io, " constants ") + join(io, getconstantnames(model), ", ") + println() + print(io, " observations ") + return join(io, getobservationnames(model), ", ") end + """ (model::Model)([rng, varinfo, sampler, context]) @@ -161,19 +156,39 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf end """ - getargnames(model::Model) + getargumentnames(model::Model) Get a tuple of the argument names of the `model`. """ -getargnames(model::Model{_F,argnames}) where {argnames,_F} = argnames +getargumentnames(model::Model{_F,argnames}) where {argnames,_F} = argnames +Base.@deprecate getargnames(model) getargumentnames(model) -""" - getmissings(model::Model) +function _filter_arguments(argnames, Targs, ::Type{T}) where {T} + return filter(i -> Targs.parameters[i] <: T, eachindex(Targs.parameters)) +end -Get a tuple of the names of the missing arguments of the `model`. """ -getmissings(model::Model{_F,_a,_d,missings}) where {missings,_F,_a,_d} = missings + getconstantnames(model::Model) +Get a tuple of the argument names of the `model`. +""" +@generated function getconstantnames(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} + return argnames[_filter_arguments(argnames, Targs, Constant)] +end +@generated function getobservationnames(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} + return argnames[_filter_arguments(argnames, Targs, Observation)] +end + +@generated function getconstants(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} + args = _filter_arguments(argnames, Targs, Constant) + return Expr(:tuple, (:(model.arguments.$arg) for arg in arguments)...) +end + +@generated function getobservations(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} + args = _filter_arguments(argnames, Targs, Observation) + return Expr(:tuple, (:(model.arguments.$arg) for arg in arguments)...) +end + """ nameof(model::Model) diff --git a/src/utils.jl b/src/utils.jl index e77a4ecdd..da66c7ddd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -169,3 +169,6 @@ end ####################### collectmaybe(x) = x collectmaybe(x::Base.AbstractSet) = collect(x) + +_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) +_getindex(x, inds::Tuple{}) = x diff --git a/src/varname.jl b/src/varname.jl index 343bb0da8..59e051ad8 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -15,30 +15,3 @@ function subsumes_string(u::String, v::String, u_indexing=u * "[") return u == v || startswith(v, u_indexing) end -""" - inargnames(varname::VarName, model::Model) - -Statically check whether the variable of name `varname` is an argument of the `model`. - -Possibly existing indices of `varname` are neglected. -""" -@generated function inargnames(::VarName{s}, ::Model{_F,argnames}) where {s,argnames,_F} - return s in argnames -end - -""" - inmissings(varname::VarName, model::Model) - -Statically check whether the variable of name `varname` is a statically declared unobserved variable -of the `model`. - -Possibly existing indices of `varname` are neglected. -""" -@generated function inmissings( - ::VarName{s}, ::Model{_F,_a,_T,missings} -) where {s,missings,_F,_a,_T} - return s in missings -end - -# HACK: Type-piracy. Is this really the way to go? -AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym From 781a86d930a6a6913667c2b39ae0e3ca402f8eac Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Mon, 21 Jun 2021 17:54:59 +0200 Subject: [PATCH 02/13] Restructure and fix some things --- src/DynamicPPL.jl | 9 +++-- src/compiler.jl | 4 +-- src/model.jl | 87 ++++++++++++++++++++++++++++++----------------- 3 files changed, 64 insertions(+), 36 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..8661561dc 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -67,9 +67,14 @@ export AbstractVarInfo, vectorize, # Model Model, - getmissings, - getargnames, + getargumentnames, + getarguments, + getconstantnames, + getconstants, + getobservationnames, + getobservations, generated_quantities, + isobservation, # Samplers Sampler, SampleFromPrior, diff --git a/src/compiler.jl b/src/compiler.jl index 7253c44dc..c5e5aa556 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -224,7 +224,7 @@ function generate_mainbody!(mod, found_internals, varnames, expr::Expr, warn) args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing L, R = args_dottilde - push!(varnames, vsym(L)) + !isliteral(L) && push!(varnames, vsym(L)) return Base.remove_linenums!( generate_dot_tilde( generate_mainbody!(mod, found_internals, varnames, L, warn), @@ -237,7 +237,7 @@ function generate_mainbody!(mod, found_internals, varnames, expr::Expr, warn) args_tilde = getargs_tilde(expr) if args_tilde !== nothing L, R = args_tilde - push!(varnames, vsym(L)) + !isliteral(L) && push!(varnames, vsym(L)) return Base.remove_linenums!( generate_tilde( generate_mainbody!(mod, found_internals, varnames, L, warn), diff --git a/src/model.jl b/src/model.jl index fcf079898..99b711529 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,23 +1,28 @@ + +""" + abstract type Argument{T} end + +Parametric wrapper type for model arguments. +""" +abstract type Argument{T} end + +struct Observation{T} <: Argument{T} + value::T +end + +struct Constant{T} <: Argument{T} + value::T +end + + """ struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram name::Symbol evaluator::F arguments::NamedTuple{argumentnames,Targs} end -A `Model` struct with model evaluation function of type `F`, and arguments `arguments`. - -# Examples - -```julia -julia> Model(f, (x = 1.0, y = 2.0)) -Model{typeof(f),(:x, :y),(),(),Tuple{Float64,Float64},Tuple{}}(f, (x = 1.0, y = 2.0), NamedTuple()) -julia> Model(f, (x = 1.0, y = 2.0), (x = 42,)) -Model{typeof(f),(:x, :y),(:x,),(),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) - -julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition of missings -Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) -``` +A `Model` struct with model evaluation function of type `F`, and arguments `arguments`. """ struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram name::Symbol @@ -27,14 +32,6 @@ struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram end -abstract type Argument{T} end -struct Observation{T} <: Argument{T} - value::T -end -struct Constant{T} <: Argument{T} - value::T -end - """ isobservation(vn, model) @@ -58,13 +55,18 @@ isobservation(::VarName, ::Constant) = false isobservation(vn::VarName, obs::Observation) = !ismissing(_getindex(obs.value, vn.indexing)) isobservation(vn::VarName, obs::Observation{Missing}) = false -function Base.show(io::IO, model::Model) + +function Base.show(io::IO, ::MIME"text/plain", model::Model) println(io, "Model ", model.name, " given") print(io, " constants ") join(io, getconstantnames(model), ", ") println() print(io, " observations ") - return join(io, getobservationnames(model), ", ") + join(io, getobservationnames(model), ", ") +end + +function Base.show(io::IO, model::Model) + println(io, "$(model.name)$(getarguments(model))") end @@ -148,45 +150,66 @@ end Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ -@generated function _evaluate( - model::Model{_F,argnames}, varinfo, context -) where {_F,argnames} - unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, context, $(unwrap_args...))) +function _evaluate(model::Model, varinfo, context) + unwrap_args = matchingvalue.((context,), (varinfo,), Tuple(getarguments(model))) + return model.evaluator(model, varinfo, context, unwrap_args...) end """ getargumentnames(model::Model) -Get a tuple of the argument names of the `model`. +Return a tuple of the argument names of the `model`. """ getargumentnames(model::Model{_F,argnames}) where {argnames,_F} = argnames Base.@deprecate getargnames(model) getargumentnames(model) +""" + getargumentnames(model::Model) + +Return the arguments passed to the model. +""" +getarguments(model::Model) = map(arg -> arg.value, model.arguments) + function _filter_arguments(argnames, Targs, ::Type{T}) where {T} return filter(i -> Targs.parameters[i] <: T, eachindex(Targs.parameters)) end """ getconstantnames(model::Model) -Get a tuple of the argument names of the `model`. + +Return a tuple of the names of the observations passed to `model`. """ @generated function getconstantnames(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} return argnames[_filter_arguments(argnames, Targs, Constant)] end +""" + getconstantnames(model::Model) + +Return a tuple of the names of the observations passed to `model`. +""" @generated function getobservationnames(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} return argnames[_filter_arguments(argnames, Targs, Observation)] end +""" + getconstants(model::Model) + +Return a tuple of the constants passed to `model`. +""" @generated function getconstants(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} args = _filter_arguments(argnames, Targs, Constant) - return Expr(:tuple, (:(model.arguments.$arg) for arg in arguments)...) + return Expr(:tuple, (:(model.arguments.$arg.value) for arg in arguments)...) end +""" + getobservationnames(model::Model) + +Return a tuple of the observations passed to `model`. +""" @generated function getobservations(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} args = _filter_arguments(argnames, Targs, Observation) - return Expr(:tuple, (:(model.arguments.$arg) for arg in arguments)...) + return Expr(:tuple, (:(model.arguments.$arg.value) for arg in arguments)...) end """ From 8244cdb06425dd55fc69ad41a83aae700422e32a Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Mon, 21 Jun 2021 17:55:08 +0200 Subject: [PATCH 03/13] Adapt test cases --- test/loglikelihoods.jl | 3 ++- test/runtests.jl | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 74fb88d70..f23bb1529 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -110,7 +110,8 @@ const gdemo_models = ( continue end - loglikelihood = if length(keys(lls)) == 1 && length(m.args.x) == 1 + args = getarguments(m) + loglikelihood = if length(keys(lls)) == 1 && length(args.x) == 1 # Only have one observation, so we need to double it # for comparison with other models. 2 * sum(lls[first(keys(lls))]) diff --git a/test/runtests.jl b/test/runtests.jl index d83be0eea..6a9264e23 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,7 +22,8 @@ using DynamicPPL: getargs_dottilde, getargs_tilde, Selector const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL))) const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing") -const GROUP = get(ENV, "GROUP", "All") +# const GROUP = get(ENV, "GROUP", "All") +const GROUP = "DynamicPPL" Random.seed!(100) @@ -36,7 +37,8 @@ include("test_util.jl") include("varinfo.jl") include("model.jl") include("sampler.jl") - include("prob_macro.jl") + # include("prob_macro.jl") + @warn "Prob macro tests turned off!!!!" include("independence.jl") include("distribution_wrappers.jl") include("contexts.jl") From 66b228c1c7b968cc6f82bec818c138006d5454d6 Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Mon, 21 Jun 2021 18:13:24 +0200 Subject: [PATCH 04/13] Remove fixed test group --- test/runtests.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6a9264e23..72f8555f7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,8 +22,7 @@ using DynamicPPL: getargs_dottilde, getargs_tilde, Selector const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL))) const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing") -# const GROUP = get(ENV, "GROUP", "All") -const GROUP = "DynamicPPL" +const GROUP = get(ENV, "GROUP", "All") Random.seed!(100) From 0a074bc3f367b479505a2f72e0e6ddd01a4ef16f Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Mon, 28 Jun 2021 17:24:08 +0200 Subject: [PATCH 05/13] Update src/model.jl Co-authored-by: Tor Erlend Fjelde --- src/model.jl | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/model.jl b/src/model.jl index 99b711529..850de913e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -41,15 +41,8 @@ A variable is an observation if it is among the arguments data of the model, and observation value is not `missing` (e.g., it could happen that the arguments contain `x = [missing, 42]` -- then `x[1]` is not an observation, but `x[2]` is.) """ -@generated function isobservation( - vn::VarName{s}, - model::Model{_F, argnames} -) where {s, _F, argnames} - if s in argnames - return :(isobservation(vn, model.arguments.$s)) - else - return :(false) - end +function isobservation(vn::VarName{s}, model::Model{<:Any, argnames}) where {s, argnames} + return (s in argnames) || isobservation(vn, getproperty(model.arguments, s)) end isobservation(::VarName, ::Constant) = false isobservation(vn::VarName, obs::Observation) = !ismissing(_getindex(obs.value, vn.indexing)) From e79a0d089e346dd4faa9020b9d43e1cf045981a7 Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Mon, 28 Jun 2021 20:05:57 +0200 Subject: [PATCH 06/13] Fix some things and implement some suggestions --- src/DynamicPPL.jl | 6 ---- src/compiler.jl | 5 ++- src/model.jl | 73 +++++++++++++++++++++++++++--------------- test/loglikelihoods.jl | 2 +- 4 files changed, 51 insertions(+), 35 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 8661561dc..0434438b5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -68,13 +68,7 @@ export AbstractVarInfo, # Model Model, getargumentnames, - getarguments, - getconstantnames, - getconstants, - getobservationnames, - getobservations, generated_quantities, - isobservation, # Samplers Sampler, SampleFromPrior, diff --git a/src/compiler.jl b/src/compiler.jl index c5e5aa556..2a60c4500 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -108,9 +108,8 @@ function model(mod, linenumbernode, expr, warn) mod, modelinfo[:modeldef][:body], warn ) - # extract parameters and observations from that - modelinfo[:paramnames] = filter(x -> x ∉ modelinfo[:varnames], modelinfo[:allargs_syms]) - modelinfo[:obsnames] = setdiff(modelinfo[:allargs_syms], modelinfo[:paramnames]) + # extract observations from that + modelinfo[:obsnames] = modelinfo[:allargs_syms] ∩ modelinfo[:varnames] return build_output(modelinfo, linenumbernode) end diff --git a/src/model.jl b/src/model.jl index 850de913e..36a3a19de 100644 --- a/src/model.jl +++ b/src/model.jl @@ -41,21 +41,42 @@ A variable is an observation if it is among the arguments data of the model, and observation value is not `missing` (e.g., it could happen that the arguments contain `x = [missing, 42]` -- then `x[1]` is not an observation, but `x[2]` is.) """ -function isobservation(vn::VarName{s}, model::Model{<:Any, argnames}) where {s, argnames} - return (s in argnames) || isobservation(vn, getproperty(model.arguments, s)) +function isobservation(vn::VarName{s}, model::Model{<:Any,argnames}) where {s,argnames} + return (s in argnames) && isobservation(vn, getproperty(model.arguments, s)) end isobservation(::VarName, ::Constant) = false -isobservation(vn::VarName, obs::Observation) = !ismissing(_getindex(obs.value, vn.indexing)) isobservation(vn::VarName, obs::Observation{Missing}) = false +isobservation(vn::VarName, obs::Observation) = !ismissing(_getindex(obs.value, vn.indexing)) function Base.show(io::IO, ::MIME"text/plain", model::Model) + constants, observations = VarName[VarName{c}() for c in getconstantnames(model)], VarName[] + for (obs, value) in pairs(getobservations(model)) + if value isa AbstractArray + all_indices = CartesianIndices(value) + missing_indices = filter(ix -> ismissing(value[ix]), all_indices) + if isempty(missing_indices) + # all observations given -- full variable observed + push!(observations, VarName{obs}()) + else + # mixed case -- indexed variables in both categories + observed_indices = setdiff(all_indices, missing_indices) + for ix in observed_indices + push!(observations, VarName{obs}((Tuple(ix),))) + end + end + else + complete_name = VarName{obs}() + !ismissing(value) && push!(observations, complete_name) + end + end + println(io, "Model ", model.name, " given") print(io, " constants ") - join(io, getconstantnames(model), ", ") + join(io, constants, ", ") println() print(io, " observations ") - join(io, getobservationnames(model), ", ") + join(io, observations, ", ") end function Base.show(io::IO, model::Model) @@ -144,8 +165,8 @@ end Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ function _evaluate(model::Model, varinfo, context) - unwrap_args = matchingvalue.((context,), (varinfo,), Tuple(getarguments(model))) - return model.evaluator(model, varinfo, context, unwrap_args...) + matched_args = map(arg -> matchingvalue(context, varinfo, arg), getarguments(model)) + return model.evaluator(model, varinfo, context, matched_args...) end """ @@ -153,18 +174,18 @@ end Return a tuple of the argument names of the `model`. """ -getargumentnames(model::Model{_F,argnames}) where {argnames,_F} = argnames +getargumentnames(model::Model{<:Any,argnames}) where {argnames} = argnames Base.@deprecate getargnames(model) getargumentnames(model) """ - getargumentnames(model::Model) + getarguments(model::Model) -Return the arguments passed to the model. +Return a `NamedTuple` of the arguments passed to the model. """ getarguments(model::Model) = map(arg -> arg.value, model.arguments) function _filter_arguments(argnames, Targs, ::Type{T}) where {T} - return filter(i -> Targs.parameters[i] <: T, eachindex(Targs.parameters)) + return [arg for (arg, Targ) in zip(argnames, Targs.parameters) if Targ <: T] end """ @@ -172,37 +193,39 @@ end Return a tuple of the names of the observations passed to `model`. """ -@generated function getconstantnames(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} - return argnames[_filter_arguments(argnames, Targs, Constant)] +@generated function getconstantnames(model::Model{<:Any,argnames,Targs}) where {argnames,Targs} + return _filter_arguments(argnames, Targs, Constant) end """ - getconstantnames(model::Model) + getobservationnames(model::Model) Return a tuple of the names of the observations passed to `model`. """ -@generated function getobservationnames(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} - return argnames[_filter_arguments(argnames, Targs, Observation)] +@generated function getobservationnames(model::Model{<:Any,argnames,Targs}) where {argnames,Targs} + return _filter_arguments(argnames, Targs, Observation) end """ getconstants(model::Model) -Return a tuple of the constants passed to `model`. +Return a `NamedTuple` of the constants passed to `model`. """ -@generated function getconstants(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} - args = _filter_arguments(argnames, Targs, Constant) - return Expr(:tuple, (:(model.arguments.$arg.value) for arg in arguments)...) +@generated function getconstants(model::Model{<:Any,argnames,Targs}) where {argnames,Targs} + args = Tuple(_filter_arguments(argnames, Targs, Constant)) + values = [:(model.arguments.$arg.value) for arg in args] + return :(NamedTuple{$args}(($(values...),))) end """ - getobservationnames(model::Model) + getobservations(model::Model) -Return a tuple of the observations passed to `model`. +Return a `NamedTuple` of the observations passed to `model` (without respecting `missing`s). """ -@generated function getobservations(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} - args = _filter_arguments(argnames, Targs, Observation) - return Expr(:tuple, (:(model.arguments.$arg.value) for arg in arguments)...) +@generated function getobservations(model::Model{<:Any,argnames,Targs}) where {argnames,Targs} + args = Tuple(_filter_arguments(argnames, Targs, Observation)) + values = [:(model.arguments.$arg.value) for arg in args] + return :(NamedTuple{$args}(($(values...),))) end """ diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index f23bb1529..c4c1d7403 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -110,7 +110,7 @@ const gdemo_models = ( continue end - args = getarguments(m) + args = DynamicPPL.getarguments(m) loglikelihood = if length(keys(lls)) == 1 && length(args.x) == 1 # Only have one observation, so we need to double it # for comparison with other models. From 852afb9e770fa562a127154f2e68ca66419b6a92 Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Sat, 3 Jul 2021 13:29:14 +0200 Subject: [PATCH 07/13] Rename Observation -> Variable, simplify argument getters --- src/compiler.jl | 2 +- src/model.jl | 108 +++++++++++++++++++++--------------------------- 2 files changed, 47 insertions(+), 63 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 2a60c4500..f85d10f05 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -378,7 +378,7 @@ function build_output(modelinfo, linenumbernode) # Extract the named tuple expression of all arguments allargs_newnames = [gensym(x) for x in modelinfo[:allargs_syms]] allargs_wrapped = [ - x ∈ modelinfo[:obsnames] ? :($(DynamicPPL.Observation)($x)) : :($(DynamicPPL.Constant)($x)) + x ∈ modelinfo[:obsnames] ? :($(DynamicPPL.Variable)($x)) : :($(DynamicPPL.Constant)($x)) for x in modelinfo[:allargs_syms] ] allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)] diff --git a/src/model.jl b/src/model.jl index 36a3a19de..0d9ff43e9 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,19 +1,25 @@ """ - abstract type Argument{T} end + abstract type Argument{T,isdefault} end Parametric wrapper type for model arguments. """ -abstract type Argument{T} end +abstract type Argument{T,isdefault} end -struct Observation{T} <: Argument{T} +struct Variable{T,isdefault} <: Argument{T,isdefault} value::T end -struct Constant{T} <: Argument{T} +Variable{isdefault}(x) where {isdefault} = Variable{typeof(x), false}(x) +Variable(x) = Variable{false}(x) + +struct Constant{T,isdefault} <: Argument{T,isdefault} value::T end +Constant{isdefault}(x) where {isdefault} = Constant{typeof(x), false}(x) +Constant(x) = Constant{false}(x) + """ struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram @@ -45,38 +51,39 @@ function isobservation(vn::VarName{s}, model::Model{<:Any,argnames}) where {s,ar return (s in argnames) && isobservation(vn, getproperty(model.arguments, s)) end isobservation(::VarName, ::Constant) = false -isobservation(vn::VarName, obs::Observation{Missing}) = false -isobservation(vn::VarName, obs::Observation) = !ismissing(_getindex(obs.value, vn.indexing)) +isobservation(vn::VarName, obs::Variable{Missing}) = false +isobservation(vn::VarName, obs::Variable) = !ismissing(_getindex(obs.value, vn.indexing)) function Base.show(io::IO, ::MIME"text/plain", model::Model) - constants, observations = VarName[VarName{c}() for c in getconstantnames(model)], VarName[] - for (obs, value) in pairs(getobservations(model)) + constants = VarName[VarName{c}() for c in getargumentnames(model, Constant)] + observed_variables = VarName[] + for (var, value) in pairs(getarguments(model, Variable)) if value isa AbstractArray all_indices = CartesianIndices(value) missing_indices = filter(ix -> ismissing(value[ix]), all_indices) if isempty(missing_indices) - # all observations given -- full variable observed - push!(observations, VarName{obs}()) + # all indexed given -- full variable observed + push!(observed_variables, VarName{var}()) else # mixed case -- indexed variables in both categories observed_indices = setdiff(all_indices, missing_indices) for ix in observed_indices - push!(observations, VarName{obs}((Tuple(ix),))) + push!(observed_variables, VarName{var}((Tuple(ix),))) end end else - complete_name = VarName{obs}() - !ismissing(value) && push!(observations, complete_name) + complete_name = VarName{var}() + !ismissing(value) && push!(observed_variables, complete_name) end end println(io, "Model ", model.name, " given") - print(io, " constants ") + print(io, " constants ") join(io, constants, ", ") - println() - print(io, " observations ") - join(io, observations, ", ") + println(io) + print(io, " observed variables ") + join(io, observed_variables, ", ") end function Base.show(io::IO, model::Model) @@ -170,64 +177,41 @@ function _evaluate(model::Model, varinfo, context) end """ - getargumentnames(model::Model) + getargumentnames(model::Model, [::Type{T}]) -Return a tuple of the argument names of the `model`. +Return a tuple of the argument names of the `model`. The second argument can be used to filter +the types of arguments (constant, variable, default) by passing an `Argument` subtype. """ getargumentnames(model::Model{<:Any,argnames}) where {argnames} = argnames -Base.@deprecate getargnames(model) getargumentnames(model) - -""" - getarguments(model::Model) - -Return a `NamedTuple` of the arguments passed to the model. -""" -getarguments(model::Model) = map(arg -> arg.value, model.arguments) - -function _filter_arguments(argnames, Targs, ::Type{T}) where {T} - return [arg for (arg, Targ) in zip(argnames, Targs.parameters) if Targ <: T] +@generated function getargumentnames( + model::Model{<:Any,argnames,Targs}, + ::Type{T} +) where {argnames,Targs,T} + return _getargumentnames(argnames, Targs, T) end - -""" - getconstantnames(model::Model) - -Return a tuple of the names of the observations passed to `model`. -""" -@generated function getconstantnames(model::Model{<:Any,argnames,Targs}) where {argnames,Targs} - return _filter_arguments(argnames, Targs, Constant) +function _getargumentnames(argnames, Targs, ::Type{T}) where {T} + return Tuple([n for (n, Targ) in zip(argnames, Targs.parameters) if Targ <: T]) end -""" - getobservationnames(model::Model) - -Return a tuple of the names of the observations passed to `model`. -""" -@generated function getobservationnames(model::Model{<:Any,argnames,Targs}) where {argnames,Targs} - return _filter_arguments(argnames, Targs, Observation) -end +Base.@deprecate getargnames(model) getargumentnames(model) """ - getconstants(model::Model) + getarguments(model::Model, [::Type{T}]) -Return a `NamedTuple` of the constants passed to `model`. +Return a `NamedTuple` of the constants passed to `model`. The second argument can be used to filter +the types of arguments (constant, variable, default) by passing an `Argument` subtype. """ -@generated function getconstants(model::Model{<:Any,argnames,Targs}) where {argnames,Targs} - args = Tuple(_filter_arguments(argnames, Targs, Constant)) - values = [:(model.arguments.$arg.value) for arg in args] - return :(NamedTuple{$args}(($(values...),))) +getarguments(model::Model) = map(arg -> arg.value, model.arguments) +@generated function getarguments( + model::Model{<:Any,argnames,TArgs}, + ::Type{T} +) where {argnames,TArgs,T} + filtered_argnames = _getargumentnames(argnames, TArgs, T) + values = [:(model.arguments.$arg.value) for arg in filtered_argnames] + return :(NamedTuple{$filtered_argnames}(($(values...),))) end -""" - getobservations(model::Model) -Return a `NamedTuple` of the observations passed to `model` (without respecting `missing`s). -""" -@generated function getobservations(model::Model{<:Any,argnames,Targs}) where {argnames,Targs} - args = Tuple(_filter_arguments(argnames, Targs, Observation)) - values = [:(model.arguments.$arg.value) for arg in args] - return :(NamedTuple{$args}(($(values...),))) -end - """ nameof(model::Model) From ab8ce2c124bb0dc15c10462209ec2891b47bbecf Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Sat, 3 Jul 2021 14:03:06 +0200 Subject: [PATCH 08/13] Move extraction functions out of show --- src/model.jl | 86 +++++++++++++++++++++++++++------------------------- 1 file changed, 44 insertions(+), 42 deletions(-) diff --git a/src/model.jl b/src/model.jl index 0d9ff43e9..614342953 100644 --- a/src/model.jl +++ b/src/model.jl @@ -37,53 +37,13 @@ struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram arguments::NamedTuple{argumentnames,Targs} end - -""" - isobservation(vn, model) - -Check whether the value of the expression `vn` is a real observation in the `model`. - -A variable is an observation if it is among the arguments data of the model, and the corresponding -observation value is not `missing` (e.g., it could happen that the arguments contain `x = -[missing, 42]` -- then `x[1]` is not an observation, but `x[2]` is.) -""" -function isobservation(vn::VarName{s}, model::Model{<:Any,argnames}) where {s,argnames} - return (s in argnames) && isobservation(vn, getproperty(model.arguments, s)) -end -isobservation(::VarName, ::Constant) = false -isobservation(vn::VarName, obs::Variable{Missing}) = false -isobservation(vn::VarName, obs::Variable) = !ismissing(_getindex(obs.value, vn.indexing)) - - function Base.show(io::IO, ::MIME"text/plain", model::Model) - constants = VarName[VarName{c}() for c in getargumentnames(model, Constant)] - observed_variables = VarName[] - for (var, value) in pairs(getarguments(model, Variable)) - if value isa AbstractArray - all_indices = CartesianIndices(value) - missing_indices = filter(ix -> ismissing(value[ix]), all_indices) - if isempty(missing_indices) - # all indexed given -- full variable observed - push!(observed_variables, VarName{var}()) - else - # mixed case -- indexed variables in both categories - observed_indices = setdiff(all_indices, missing_indices) - for ix in observed_indices - push!(observed_variables, VarName{var}((Tuple(ix),))) - end - end - else - complete_name = VarName{var}() - !ismissing(value) && push!(observed_variables, complete_name) - end - end - println(io, "Model ", model.name, " given") print(io, " constants ") - join(io, constants, ", ") + join(io, getconstants(model), ", ") println(io) print(io, " observed variables ") - join(io, observed_variables, ", ") + join(io, getobservedvariables(model), ", ") end function Base.show(io::IO, model::Model) @@ -176,6 +136,7 @@ function _evaluate(model::Model, varinfo, context) return model.evaluator(model, varinfo, context, matched_args...) end + """ getargumentnames(model::Model, [::Type{T}]) @@ -211,6 +172,47 @@ getarguments(model::Model) = map(arg -> arg.value, model.arguments) return :(NamedTuple{$filtered_argnames}(($(values...),))) end +""" + isobservation(vn, model) + +Check whether the value of the expression `vn` is a real observation in the `model`. + +A variable is an observation if it is among the arguments data of the model, and the corresponding +observation value is not `missing` (e.g., it could happen that the arguments contain `x = +[missing, 42]` -- then `x[1]` is not an observation, but `x[2]` is.) +""" +function isobservation(vn::VarName{s}, model::Model{<:Any,argnames}) where {s,argnames} + return (s in argnames) && isobservation(vn, getproperty(model.arguments, s)) +end +isobservation(::VarName, ::Constant) = false +isobservation(vn::VarName, obs::Variable{Missing}) = false +isobservation(vn::VarName, obs::Variable) = !ismissing(_getindex(obs.value, vn.indexing)) + +function getobservedvariables(model::Model) + observed_variables = VarName[] + for (var, value) in pairs(getarguments(model, Variable)) + if value isa AbstractArray + all_indices = CartesianIndices(value) + missing_indices = filter(ix -> ismissing(value[ix]), all_indices) + if isempty(missing_indices) + # all indexed given -- full variable observed + push!(observed_variables, VarName{var}()) + else + # mixed case -- indexed variables in both categories + observed_indices = setdiff(all_indices, missing_indices) + for ix in observed_indices + push!(observed_variables, VarName{var}((Tuple(ix),))) + end + end + else + complete_name = VarName{var}() + !ismissing(value) && push!(observed_variables, complete_name) + end + end + return observed_variables +end + +getconstants(model::Model) = VarName[VarName{c}() for c in getargumentnames(model, Constant)] """ nameof(model::Model) From 1979a64166ef6af445af20561fffd3d0dc38d85b Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Sat, 3 Jul 2021 15:08:34 +0200 Subject: [PATCH 09/13] Properly store default arguments --- src/compiler.jl | 28 +++++++++++++++++----------- src/model.jl | 19 +++++++++---------- test/runtests.jl | 1 - 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index f85d10f05..6035c4105 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -138,6 +138,7 @@ function build_model_info(input_expr) modelinfo = Dict( :allargs_exprs => [], :allargs_syms => [], + :allargs_defaults => [], :modeldef => modeldef, ) return modelinfo @@ -146,17 +147,18 @@ function build_model_info(input_expr) # Extract the positional and keyword arguments from the model definition. allargs = vcat(modeldef[:args], modeldef[:kwargs]) - # Split the argument expressions and the default values. - allargs_exprs_defaults = map(allargs) do arg - MacroTools.@match arg begin + # Split the argument expressions and the default values, by unzipping allargs, taking care of the + # empty case + allargs_exprs, allargs_defaults = foldl(allargs, init=([], [])) do (ae, ad), arg + (expr, default) = MacroTools.@match arg begin (x_ = val_) => (x, val) x_ => (x, NO_DEFAULT) end + push!(ae, expr) + push!(ad, default) + ae, ad end - - # Extract the expressions of the arguments, without default values. - allargs_exprs = first.(allargs_exprs_defaults) - + # Extract the names of the arguments. allargs_syms = map(allargs_exprs) do arg MacroTools.@match arg begin @@ -169,6 +171,7 @@ function build_model_info(input_expr) modelinfo = Dict( :allargs_exprs => allargs_exprs, :allargs_syms => allargs_syms, + :allargs_defaults => allargs_defaults, :modeldef => modeldef, ) @@ -377,10 +380,13 @@ function build_output(modelinfo, linenumbernode) # Extract the named tuple expression of all arguments allargs_newnames = [gensym(x) for x in modelinfo[:allargs_syms]] - allargs_wrapped = [ - x ∈ modelinfo[:obsnames] ? :($(DynamicPPL.Variable)($x)) : :($(DynamicPPL.Constant)($x)) - for x in modelinfo[:allargs_syms] - ] + allargs_wrapped = map(modelinfo[:allargs_syms], modelinfo[:allargs_defaults]) do x, d + if x ∈ modelinfo[:obsnames] + :($(DynamicPPL.Variable)($x, $d)) + else + :($(DynamicPPL.Constant)($x, $d)) + end + end allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)] allargs_namedtuple = to_namedtuple_expr(modelinfo[:allargs_syms], allargs_newnames) diff --git a/src/model.jl b/src/model.jl index 614342953..38fb174e7 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,25 +1,21 @@ """ - abstract type Argument{T,isdefault} end + abstract type Argument{T,Tdefault} end Parametric wrapper type for model arguments. """ -abstract type Argument{T,isdefault} end +abstract type Argument{T,Tdefault} end -struct Variable{T,isdefault} <: Argument{T,isdefault} +struct Variable{T,Tdefault} <: Argument{T,Tdefault} value::T + default::Tdefault end -Variable{isdefault}(x) where {isdefault} = Variable{typeof(x), false}(x) -Variable(x) = Variable{false}(x) - -struct Constant{T,isdefault} <: Argument{T,isdefault} +struct Constant{T,Tdefault} <: Argument{T,Tdefault} value::T + default::Tdefault end -Constant{isdefault}(x) where {isdefault} = Constant{typeof(x), false}(x) -Constant(x) = Constant{false}(x) - """ struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram @@ -172,6 +168,9 @@ getarguments(model::Model) = map(arg -> arg.value, model.arguments) return :(NamedTuple{$filtered_argnames}(($(values...),))) end +hasdefault(model::Model, argname::Symbol) = getdefault(model, argname) !== NO_DEFAULT +getdefault(model::Model, argname::Symbol) = getproperty(model.arguments, argname).default + """ isobservation(vn, model) diff --git a/test/runtests.jl b/test/runtests.jl index 72f8555f7..98be45b2a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,7 +37,6 @@ include("test_util.jl") include("model.jl") include("sampler.jl") # include("prob_macro.jl") - @warn "Prob macro tests turned off!!!!" include("independence.jl") include("distribution_wrappers.jl") include("contexts.jl") From 32184f2d6dcbedd030fcbe4f0e1ab6d8c867b650 Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Sat, 3 Jul 2021 17:22:47 +0200 Subject: [PATCH 10/13] Make prob macro code work --- src/compiler.jl | 4 +-- src/model.jl | 8 ++++++ src/prob_macro.jl | 66 +++++++++++++++++++++++------------------------ test/runtests.jl | 2 +- 4 files changed, 44 insertions(+), 36 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6035c4105..5c63bd062 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -147,8 +147,8 @@ function build_model_info(input_expr) # Extract the positional and keyword arguments from the model definition. allargs = vcat(modeldef[:args], modeldef[:kwargs]) - # Split the argument expressions and the default values, by unzipping allargs, taking care of the - # empty case + # Split the argument expressions and the default values, by unzipping allargs, taking care of + # the empty case allargs_exprs, allargs_defaults = foldl(allargs, init=([], [])) do (ae, ad), arg (expr, default) = MacroTools.@match arg begin (x_ = val_) => (x, val) diff --git a/src/model.jl b/src/model.jl index 38fb174e7..385c645bd 100644 --- a/src/model.jl +++ b/src/model.jl @@ -11,11 +11,15 @@ struct Variable{T,Tdefault} <: Argument{T,Tdefault} default::Tdefault end +Variable(x) = Variable(x, NO_DEFAULT) + struct Constant{T,Tdefault} <: Argument{T,Tdefault} value::T default::Tdefault end +Constant(x) = Constant(x, NO_DEFAULT) + """ struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram @@ -168,9 +172,13 @@ getarguments(model::Model) = map(arg -> arg.value, model.arguments) return :(NamedTuple{$filtered_argnames}(($(values...),))) end +hasargument(model::Model, argname::Symbol) = isdefined(model.arguments, argname) +getargument(model::Model, argname::Symbol) = getproperty(model.arguments, argname).value + hasdefault(model::Model, argname::Symbol) = getdefault(model, argname) !== NO_DEFAULT getdefault(model::Model, argname::Symbol) = getproperty(model.arguments, argname).default + """ isobservation(vn, model) diff --git a/src/prob_macro.jl b/src/prob_macro.jl index d761e9fdc..09c50423e 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -1,3 +1,12 @@ +function _isdefault(argnames, TArgs, arg) + ix = findfirst(==(arg), argnames) + if !ismissing(ix) + return !(TArgs.parameters[ix] <: Argument{<:Any, NoDefault}) + else + return false + end +end + macro logprob_str(str) expr1, expr2 = get_exprs(str) return :(logprob($(esc(expr1)), $(esc(expr2)))) @@ -53,11 +62,12 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names else vi = nothing end - defaults = model.defaults - @assert all(getargnames(model)) do arg + @assert all(getargumentnames(model)) do arg isdefined(ntl, arg) || isdefined(ntr, arg) || - isdefined(defaults, arg) && getfield(defaults, arg) !== missing + (hasargument(model, arg) && + hasdefault(model, arg) && + !ismissing(getdefault(model, arg))) end return Val(:likelihood), model, vi else @@ -78,9 +88,8 @@ end function probtype( left::NamedTuple{leftnames}, right::NamedTuple{rightnames}, - model::Model{_F,argnames,defaultnames}, -) where {leftnames,rightnames,argnames,defaultnames,_F} - defaults = model.defaults + model::Model{_F,argnames}, +) where {leftnames,rightnames,argnames,_F} prior_rhs = all( n -> n in (:model, :varinfo) || n in argnames && getfield(right, n) !== missing, rightnames, @@ -90,8 +99,8 @@ function probtype( return getfield(left, arg) elseif arg in rightnames return getfield(right, arg) - elseif arg in defaultnames - return getfield(defaults, arg) + elseif hasargument(model, arg) && hasdefault(model, arg) + return getdefault(model, arg) else return nothing end @@ -153,33 +162,28 @@ end @generated function make_prior_model( left::NamedTuple{leftnames}, right::NamedTuple{rightnames}, - model::Model{_F,argnames,defaultnames}, -) where {leftnames,rightnames,argnames,defaultnames,_F} + model::Model{_F,argnames,TArgs}, +) where {leftnames,rightnames,argnames,TArgs,_F} argvals = [] - missings = [] warnings = [] for argname in argnames if argname in leftnames - push!(argvals, :(deepcopy(left.$argname))) - push!(missings, argname) + push!(argvals, :(Constant(deepcopy(left.$argname)))) elseif argname in rightnames - push!(argvals, :(right.$argname)) - elseif argname in defaultnames - push!(argvals, :(model.defaults.$argname)) + push!(argvals, :(Variable(right.$argname))) + elseif _isdefault(argnames, TArgs, argname) + push!(argvals, :(Variable(model.arguments.$argname.default))) else push!(warnings, :(@warn($(warn_msg(argname))))) - push!(argvals, :(nothing)) + push!(argvals, :(Variable(nothing))) end end # `args` is inserted as properly typed NamedTuple expression; - # `missings` is splatted into a tuple at compile time and inserted as literal return quote $(warnings...) - Model{$(Tuple(missings))}( - model.name, model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults - ) + Model(model.name, model.evaluator, $(to_namedtuple_expr(argnames, argvals))) end end @@ -213,19 +217,17 @@ end @generated function make_likelihood_model( left::NamedTuple{leftnames}, right::NamedTuple{rightnames}, - model::Model{_F,argnames,defaultnames}, -) where {leftnames,rightnames,argnames,defaultnames,_F} + model::Model{_F,argnames,TArgs}, +) where {leftnames,rightnames,argnames,TArgs,_F} argvals = [] - missings = [] - + for argname in argnames if argname in leftnames - push!(argvals, :(left.$argname)) + push!(argvals, :(Variable(left.$argname))) elseif argname in rightnames - push!(argvals, :(right.$argname)) - push!(missings, argname) - elseif argname in defaultnames - push!(argvals, :(model.defaults.$argname)) + push!(argvals, :(Constant(right.$argname))) + elseif _isdefault(argnames, TArgs, argname) + push!(argvals, :(Variable(model.arguments.$argname.default))) else throw( "This point should not be reached. Please open an issue in the DynamicPPL.jl repository.", @@ -235,7 +237,5 @@ end # `args` is inserted as properly typed NamedTuple expression; # `missings` is splatted into a tuple at compile time and inserted as literal - return :(Model{$(Tuple(missings))}( - model.name, model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults - )) + return :(Model(model.name, model.evaluator, $(to_namedtuple_expr(argnames, argvals)))) end diff --git a/test/runtests.jl b/test/runtests.jl index 98be45b2a..d83be0eea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,7 +36,7 @@ include("test_util.jl") include("varinfo.jl") include("model.jl") include("sampler.jl") - # include("prob_macro.jl") + include("prob_macro.jl") include("independence.jl") include("distribution_wrappers.jl") include("contexts.jl") From f070f82f0449791a210947faa4af2cea9ea82281 Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Sun, 4 Jul 2021 11:47:57 +0200 Subject: [PATCH 11/13] Implement some JuliaFormatter suggestions --- src/compiler.jl | 2 +- src/model.jl | 2 ++ src/prob_macro.jl | 11 +++++------ 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 5c63bd062..5f691bfcc 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -149,7 +149,7 @@ function build_model_info(input_expr) # Split the argument expressions and the default values, by unzipping allargs, taking care of # the empty case - allargs_exprs, allargs_defaults = foldl(allargs, init=([], [])) do (ae, ad), arg + allargs_exprs, allargs_defaults = foldl(allargs; init=([], [])) do (ae, ad), arg (expr, default) = MacroTools.@match arg begin (x_ = val_) => (x, val) x_ => (x, NO_DEFAULT) diff --git a/src/model.jl b/src/model.jl index 385c645bd..9d00d078e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -44,10 +44,12 @@ function Base.show(io::IO, ::MIME"text/plain", model::Model) println(io) print(io, " observed variables ") join(io, getobservedvariables(model), ", ") + return nothing end function Base.show(io::IO, model::Model) println(io, "$(model.name)$(getarguments(model))") + return nothing end diff --git a/src/prob_macro.jl b/src/prob_macro.jl index 09c50423e..1fdc1da38 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -1,7 +1,7 @@ function _isdefault(argnames, TArgs, arg) ix = findfirst(==(arg), argnames) if !ismissing(ix) - return !(TArgs.parameters[ix] <: Argument{<:Any, NoDefault}) + return !(TArgs.parameters[ix] <: Argument{<:Any,NoDefault}) else return false end @@ -63,11 +63,10 @@ function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {names vi = nothing end @assert all(getargumentnames(model)) do arg - isdefined(ntl, arg) || - isdefined(ntr, arg) || - (hasargument(model, arg) && - hasdefault(model, arg) && - !ismissing(getdefault(model, arg))) + isdefined(ntl, arg) || isdefined(ntr, arg) || + hasargument(model, arg) && + hasdefault(model, arg) && + !ismissing(getdefault(model, arg)) end return Val(:likelihood), model, vi else From 1b3291873c9414c02abe0f02ab231c6411ab5484 Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Sun, 4 Jul 2021 12:04:37 +0200 Subject: [PATCH 12/13] Basic extraction of internal variables --- src/compiler.jl | 9 +++++++++ src/model.jl | 20 ++++++++++++++------ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 5f691bfcc..7b6b3184f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -110,6 +110,7 @@ function model(mod, linenumbernode, expr, warn) # extract observations from that modelinfo[:obsnames] = modelinfo[:allargs_syms] ∩ modelinfo[:varnames] + modelinfo[:latentnames] = setdiff(modelinfo[:varnames], modelinfo[:allargs_syms]) return build_output(modelinfo, linenumbernode) end @@ -389,6 +390,12 @@ function build_output(modelinfo, linenumbernode) end allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)] allargs_namedtuple = to_namedtuple_expr(modelinfo[:allargs_syms], allargs_newnames) + + internals_newnames = [gensym(x) for x in modelinfo[:latentnames]] + internals_decls = map(internals_newnames) do name + :($name = $(DynamicPPL.Variable)(missing)) + end + internals_namedtuple = to_namedtuple_expr(modelinfo[:latentnames], internals_newnames) # Update the function body of the user-specified model. # We use a name for the anonymous evaluator that does not conflict with other variables. @@ -401,10 +408,12 @@ function build_output(modelinfo, linenumbernode) $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) $(allargs_decls...) + $(internals_decls...) return $(DynamicPPL.Model)( $(QuoteNode(modeldef[:name])), $evaluator, $allargs_namedtuple, + $internals_namedtuple, ) end diff --git a/src/model.jl b/src/model.jl index 9d00d078e..70bd238fd 100644 --- a/src/model.jl +++ b/src/model.jl @@ -30,20 +30,24 @@ Constant(x) = Constant(x, NO_DEFAULT) A `Model` struct with model evaluation function of type `F`, and arguments `arguments`. """ -struct Model{F, argumentnames, Targs} <: AbstractProbabilisticProgram +struct Model{F, argumentnames, Targs, internalnames, Tinternals} <: AbstractProbabilisticProgram name::Symbol # code::Expr evaluator::F arguments::NamedTuple{argumentnames,Targs} + internal_variables::NamedTuple{internalnames,Tinternals} end function Base.show(io::IO, ::MIME"text/plain", model::Model) println(io, "Model ", model.name, " given") - print(io, " constants ") + print(io, " constants: ") join(io, getconstants(model), ", ") println(io) - print(io, " observed variables ") + print(io, " observed variables: ") join(io, getobservedvariables(model), ", ") + println(io) + print(io, " latent variables: ") + join(io, getlatentvariables(model), ", ") return nothing end @@ -166,10 +170,10 @@ the types of arguments (constant, variable, default) by passing an `Argument` su """ getarguments(model::Model) = map(arg -> arg.value, model.arguments) @generated function getarguments( - model::Model{<:Any,argnames,TArgs}, + model::Model{<:Any,argnames,Targs}, ::Type{T} -) where {argnames,TArgs,T} - filtered_argnames = _getargumentnames(argnames, TArgs, T) +) where {argnames,Targs,T} + filtered_argnames = _getargumentnames(argnames, Targs, T) values = [:(model.arguments.$arg.value) for arg in filtered_argnames] return :(NamedTuple{$filtered_argnames}(($(values...),))) end @@ -221,6 +225,10 @@ function getobservedvariables(model::Model) return observed_variables end +function getlatentvariables(model::Model) + return [VarName{var}() for var in keys(model.internal_variables)] +end + getconstants(model::Model) = VarName[VarName{c}() for c in getargumentnames(model, Constant)] """ From 9df5eb22c4e9f816f7e0fdab4b844729e7aa0d5a Mon Sep 17 00:00:00 2001 From: Philipp Gabler Date: Wed, 7 Jul 2021 22:01:56 +0200 Subject: [PATCH 13/13] Implement proper separation of latent variables --- src/model.jl | 31 ++++++++++++++++++++++--------- src/prob_macro.jl | 11 +++++++++-- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/model.jl b/src/model.jl index 70bd238fd..976e54cf8 100644 --- a/src/model.jl +++ b/src/model.jl @@ -201,8 +201,11 @@ isobservation(::VarName, ::Constant) = false isobservation(vn::VarName, obs::Variable{Missing}) = false isobservation(vn::VarName, obs::Variable) = !ismissing(_getindex(obs.value, vn.indexing)) -function getobservedvariables(model::Model) +function getvariables_separated(model::Model) observed_variables = VarName[] + latent_variables = VarName[] + + # separate the argument variables into observed and latend for (var, value) in pairs(getarguments(model, Variable)) if value isa AbstractArray all_indices = CartesianIndices(value) @@ -212,23 +215,33 @@ function getobservedvariables(model::Model) push!(observed_variables, VarName{var}()) else # mixed case -- indexed variables in both categories - observed_indices = setdiff(all_indices, missing_indices) - for ix in observed_indices - push!(observed_variables, VarName{var}((Tuple(ix),))) + for ix in all_indices + complete_name = VarName{var}((Tuple(ix),)) + if ix in missing_indices + push!(latent_variables, complete_name) + else + push!(observed_variables, complete_name) + end end end else complete_name = VarName{var}() - !ismissing(value) && push!(observed_variables, complete_name) + if ismissing(value) + push!(latent_variables, complete_name) + else + push!(observed_variables, complete_name) + end end end - return observed_variables -end -function getlatentvariables(model::Model) - return [VarName{var}() for var in keys(model.internal_variables)] + # add purely internal variables as latent + append!(latent_variables, (VarName{var}() for var in keys(model.internal_variables))) + + return (observed_variables, latent_variables) end +getobservedvariables(model::Model) = getvariables_separated(model)[1] +getlatentvariables(model::Model) = getvariables_separated(model)[2] getconstants(model::Model) = VarName[VarName{c}() for c in getargumentnames(model, Constant)] """ diff --git a/src/prob_macro.jl b/src/prob_macro.jl index 1fdc1da38..f25f19d83 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -182,7 +182,10 @@ end # `args` is inserted as properly typed NamedTuple expression; return quote $(warnings...) - Model(model.name, model.evaluator, $(to_namedtuple_expr(argnames, argvals))) + Model(model.name, + model.evaluator, + $(to_namedtuple_expr(argnames, argvals)), + model.internal_variables) end end @@ -236,5 +239,9 @@ end # `args` is inserted as properly typed NamedTuple expression; # `missings` is splatted into a tuple at compile time and inserted as literal - return :(Model(model.name, model.evaluator, $(to_namedtuple_expr(argnames, argvals)))) + return :(Model( + model.name, + model.evaluator, + $(to_namedtuple_expr(argnames, argvals)), + model.internal_variables)) end