-
Notifications
You must be signed in to change notification settings - Fork 37
Refactor model for the AbstractPPL way #244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
src/model.jl
Outdated
| struct Model{F, parameternames, observationnames, Tparams, Tobs} <: AbstractProbabilisticProgram | ||
| name::Symbol | ||
| f::F | ||
| args::NamedTuple{argnames,Targs} | ||
| defaults::NamedTuple{defaultnames,Tdefaults} | ||
| evaluator::F | ||
| parameters::NamedTuple{parameternames,Tparams} | ||
| observations::NamedTuple{observationnames,Tobs} | ||
| end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are breaking changes. Is it possible and worth the effort to introduce the new supertype (which is not breaking itself) without changing the type parameters and fields of Model?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm also a bit skeptical about encoding this in Model. Would it maybe be better to just have a ContextualModel where we couple a AbstractContext to a Model instance, e.g. we make a ConditionContext, and condition(model, ...) just returns the a ContextualModel{<:ConditionContext, <:Model} or w/e. I'm guessing the motivation behind stuff like condition being "structural" (in the sense that it "changes" the model itself rather than just wrapping it) is described somewhere else, so I'll raise this more "general" comment there.
But for this PR I agree with the above: an abstract Model so that we can experiment without making breaking changes is a good idea:)
EDIT: The motivation behind the above is that observations should never have to be specified at construction of the model IMO. Arguments of the model can be observations or just general "hyperparameters" of a model. condition should then specify what's observed and what's not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The idea was that the @model macro builds a model constructor function that will then automatically construct the conditioned model, based on the same principles that we have now with inargnames and missings. And it should do this in a way that behaves like condition(generative_model, observations), where the generative model is what the evaluator function basically expresses.
Now, I have thought about separating those two cases, so that the evaluator function goes into a GenerativeModel, and the model constructor then applies condition, but I realized that one type should suffice, since a generative model is just a model conditioned on an empty set of observations.
| julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition of missings | ||
| Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) | ||
| TODO |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason to shorten the docstring and remove the example?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it will look differently and I haven't yet updated it :P
| """ | ||
| Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple) | ||
| Create a model of name `name` with evaluation function `f` and missing arguments | ||
| overwritten by `missings`. | ||
| """ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This docstring should not be removed. It was added on purpose due to a bug in Julia (otherwise it won't show up in ? Model IIRC). I can look up the relevant PR.
src/model.jl
Outdated
| Macro with more convenient syntax for declaring `Model` types with observations (similar to the | ||
| `Base.@NamedTuple` macro). The observations to the parameters part of the braces: | ||
| `@ConditionedModel{; x::Int, y}`. Type annotations can be omitted, in which case the type is | ||
| defaulted to `Any`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you have an example where this macro is needed? Currently users never explicitly construct a Model instance.
src/model.jl
Outdated
| Type alias for models without observations. | ||
| """ | ||
| const GenerativeModel{F, parameternames, Tparams} = @ConditionedModel{F, parameternames, Tparams;} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here, do you have an example where it is needed?
src/model.jl
Outdated
| @generated function _evaluate(rng, model::Model{_F,argnames}, varinfo, sampler, context) where {_F,argnames} | ||
| unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] | ||
| return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) | ||
| function _evaluate(rng, model::Model, varinfo, sampler, context) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess there was a reason for why this was a generated function before. It would be good to benchmark it before making it a regular function.
src/model.jl
Outdated
| """ | ||
| getargnames(model::Model{_F,argnames}) where {argnames,_F} = argnames | ||
|
|
||
| getparameternames(model::Model{_F,parameternames}) where {_F,parameternames} = parameternames |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Depending on whether we plan a breaking release or not, these should be deprecated (if exported). In any case it might be good to deprecate them and make a non-breaking release such that it is easier to update Turing and other downstream packages (if there are any 😄).
|
|
||
| # See also [`logjoint`](@ref) and [`loglikelihood`](@ref). | ||
| # """ | ||
| # function logprior(model::Model, varinfo::AbstractVarInfo) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be good to still be able to evaluate log prior and log likelihoods.
src/compiler.jl
Outdated
|
|
||
| if left isa Symbol || left isa Expr | ||
| @gensym out vn inds isassumption | ||
| @gensym out vn inds isobservation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should benchmark all changes here, similar to @torfjelde's PR a while ago. There were some surprising performance regressions that had to be fixed.
Yes, this seems problematic. If one would define a specialized evaluation function that distinguishes parameters and observations already when constructing the |
| @@ -1,44 +1,12 @@ | |||
| const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " * | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " * | |
| const DISTMSG = | |
| "Right-hand side of a ~ must be subtype of Distribution or a vector of " * |
|
|
||
| # failsafe: a literal is never an assumption | ||
| isassumption(expr) = :(false) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| return unwrap_right_left_vns(right, left, vns) | ||
| end | ||
|
|
||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| # Generate main body | ||
| modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) | ||
| # Generate main body and find all variable symbols | ||
| modelinfo[:body], modelinfo[:varnames] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| modelinfo[:body], modelinfo[:varnames] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn) | |
| modelinfo[:body], modelinfo[:varnames] = generate_mainbody( | |
| mod, modelinfo[:modeldef][:body], warn | |
| ) |
| # extract parameters and observations from that | ||
| modelinfo[:paramnames] = filter(x -> x ∉ modelinfo[:varnames], modelinfo[:allargs_syms]) | ||
| modelinfo[:obsnames] = setdiff(modelinfo[:allargs_syms], modelinfo[:paramnames]) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| @generated function getobservationnames(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} | ||
| obs_indices = filter(i -> Targs.parameters[i] <: Observation, eachindex(Targs.parameters)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| @generated function getobservationnames(model::Model{_F,argnames,Targs}) where {_F,argnames,Targs} | |
| obs_indices = filter(i -> Targs.parameters[i] <: Observation, eachindex(Targs.parameters)) | |
| @generated function getobservationnames( | |
| model::Model{_F,argnames,Targs} | |
| ) where {_F,argnames,Targs} | |
| obs_indices = filter( | |
| i -> Targs.parameters[i] <: Observation, eachindex(Targs.parameters) | |
| ) |
| logprior(model::Model, varinfo::AbstractVarInfo) | ||
| Return the log prior probability of variables `varinfo` for the probabilistic `model`. | ||
| function AbstractPPL.decondition(model::Model, name = Symbol(model.name, "_joint")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| function AbstractPPL.decondition(model::Model, name = Symbol(model.name, "_joint")) | |
| function AbstractPPL.decondition(model::Model, name=Symbol(model.name, "_joint")) |
| function AbstractPPL.condition(model::Model, observations, name = Symbol(model.name, "_cond")) | ||
| return Model(name, model.evaluator, model.parameters, merge(model.observations, observations)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| function AbstractPPL.condition(model::Model, observations, name = Symbol(model.name, "_cond")) | |
| return Model(name, model.evaluator, model.parameters, merge(model.observations, observations)) | |
| function AbstractPPL.condition(model::Model, observations, name=Symbol(model.name, "_cond")) | |
| return Model( | |
| name, model.evaluator, model.parameters, merge(model.observations, observations) | |
| ) |
|
|
||
| """ | ||
| loglikelihood(model::Model, varinfo::AbstractVarInfo) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| model(varinfo, LikelihoodContext()) | ||
| return getlogp(varinfo) | ||
| end | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| # return :(Modelngs}(name, f, args, defaults)) | ||
| # end | ||
|
|
||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| vn::VarName{s}, | ||
| model::Model{_F, argnames} | ||
| ) where {s, _F, argnames} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| vn::VarName{s}, | |
| model::Model{_F, argnames} | |
| ) where {s, _F, argnames} | |
| vn::VarName{s}, model::Model{_F,argnames} | |
| ) where {s,_F,argnames} |
| isobservation(vn::VarName, obs::Observation) = !ismissing(_getindex(obs, vn.indexing)) | ||
| isobservation(vn::VarName, obs::Observation{Missing}) = false | ||
|
|
||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
|
Closed in favour of #268 for now, which will allow for a nicer way of implementing |
Now, here I have begun refactorings that go in hand with introducing the AbstractPPL abstractions. For now it is mostly changing what goes in
Model, and how the compiler macro constructs it.This is not in a usable state right now; but it's less boilerplate, unless I forgot something important in the stuff I removed.
Since the structural part AbstractPPL is (in the currently proposed idea) based on the
conditionanddeconditionoperations, I decided to try a different approach than the current combination of "arguments" and "missings". Construction of this stuff should, I think, work; it remains to adapt the model macro and make evaluation work.Couple of questions:
Putting
parametersandobservationsin the model macro requires distinguishing them somehow. As a first option, the easiest way to do that would be Soss-style: make@modelalways construct aGenerativeModel, and useconditionlater. Which is totally breaking what people expect from the previous interface, though.Second alternative would be to mark either parameters or observations syntactically. I kind of like this, but have heard voices against it; also, it changes from one non-standard evaluation semantics to another.
Third option, most advanced but perhaps intuitive one, would be to do this via the same kind of process that
TypedVarInfouses, by running one evaluation.The
parameters/observationssplit also requires them to be separated in the evaluator function. Is it a problem to just pass two named tuples instead of the original argument form? Or was that designed for some reasons I don't know?Also, how does this interact with the
::Type{T} = Float64hack?