Reduce usage of sampler #1936
Conversation
src/Turing.jl (outdated)

```julia
# Convenient for end-user.
function LogDensityFunction(
    model::DynamicPPL.Model,
    varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.VarInfo(model),
    context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext(),
)
    return LogDensityFunction(model, varinfo, context)
end
```
Should this be added at this point? IIRC it's not exported and not suggested that users use it (I think it's not even documented).
IMO yeah; in fact, I kind of want to export it.
Now that we're adopting the LogDensityProblems.jl interface more aggressively, I also find myself using it very often. I also very much like it because most users will find the "full" description necessary to run the model, i.e. model + varinfo + context, too "scary". If we can say "if you want to play around with your model, you can convert it to a simple thing which is compatible with LogDensityProblems", I think that's quite nice.
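As a hedged sketch of what this looks like in practice (assuming the convenience constructor from this PR; the toy model `demo` is purely illustrative):

```julia
using Turing, LogDensityProblems

# A toy model, purely for illustration.
@model function demo()
    x ~ Normal()
end

# Wrap the model; varinfo and context default to VarInfo(model) and
# DefaultContext(), so the user never has to see them.
f = Turing.LogDensityFunction(demo())

# Evaluate the joint log density at x = 0 via the LogDensityProblems interface.
LogDensityProblems.logdensity(f, [0.0])
```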
I like it as well and I think we want to expose it somehow. My main questions are just whether the right time to do this is during the refactoring, and whether we want to expose it via a different API. For instance, we could also use something like

```julia
logjoint(model, ::Type{T}=VarInfo) where {T<:AbstractVarInfo} = LogDensityFunction(model, T(model), DefaultContext())
logprior(model, ::Type{T}=VarInfo) where {T<:AbstractVarInfo} = LogDensityFunction(model, T(model), PriorContext())
loglikelihood(model, ::Type{T}=VarInfo) where {T<:AbstractVarInfo} = LogDensityFunction(model, T(model), LikelihoodContext())
```

and keep `LogDensityFunction` an internal detail. Maybe that would be even easier for users.
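If this proposal were adopted, usage might look like the following sketch (hypothetical: `logjoint(model)` returning a LogDensityProblems-compatible object is the suggestion above, not the current API):

```julia
# Hypothetical usage of the proposed API.
m = demo()                                 # some DynamicPPL model
ℓ = logjoint(m)                            # internally a LogDensityFunction over the joint
lp = LogDensityProblems.logdensity(ℓ, θ)   # θ: a parameter vector for `m`
```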
Though I do like this interface, I feel like we still need to inform the end-user about it and might as well export it given that they'll be seeing it "a lot"?
I'm not sure they have to know about `LogDensityFunction` at all if `logjoint` etc. were the API. The only relevant point for the users would be (and should be documented) that these functions return something that follows the LogDensityProblems interface. Fields, names, etc. should not matter for users.
This constructor is actually quite annoying, as it doesn't let you re-use an existing `AbstractVarInfo` 😕
It was just the first thing I thought of; maybe there are better approaches. But maybe that's also an indication that a separate PR and discussion about the API could be useful.
Could always add `make_logjoint`, etc.?
I agree with a separate PR, but I'm worried it's going to delay exposing the LogDensityProblems.jl functionality to the users. I'll make a separate PR for now, and if it takes too long we can just export `LogDensityFunction` 🤷
Or something like `logjointof` etc., similar to `(log)densityof`.

On second thought, I think generally (and particularly when dealing with other model types) it would be nice if there were some curried form of `logjoint` etc. that gives you a function you can evaluate for different variables, and that would be consistent with the existing `logjoint` etc. Generally, one would only want to provide the variables when evaluating the curried form. Maybe one could add a keyword argument for the `LogDensityFunction` case (similar to e.g. how one can adjust the behaviour of AD backends in LogDensityProblemsAD by providing an example of the values the log density function will be evaluated with).

To summarize: the desire to provide `VarInfo(model)` when constructing the `LogDensityFunction` seems very specific to this case (and maybe not something many users would want to do?) and not something you would want to do for every kind of model. Hence maybe it should not be a first-class part of a generic API, but something more specific to this case, such as a keyword argument.
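A minimal sketch of such a curried form (hypothetical: neither `logjointof` nor this method currently exists):

```julia
# Hypothetical: return a function of the variables only, consistent with the
# existing two-argument DynamicPPL.logjoint(model, x).
logjointof(model::DynamicPPL.Model) = Base.Fix1(DynamicPPL.logjoint, model)

# ℓ = logjointof(model)
# ℓ(x)  # equivalent to DynamicPPL.logjoint(model, x)
```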
```julia
getADbackend(spl::Sampler) = getADbackend(spl.alg)
getADbackend(::SampleFromPrior) = ADBackend()()
```
Are these still needed? Or could we remove them and specialize `getADbackend(::SamplingContext{<:Sampler})` and `getADbackend(::SamplingContext{SampleFromPrior})`? I think mentally it's easier if `getADbackend` only operates on contexts.
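A sketch of what those context-only methods might look like (assuming `SamplingContext` carries the sampler as a type parameter and in a `sampler` field; the exact parameter positions may differ):

```julia
# Hypothetical: dispatch on the sampler stored inside the context rather than
# on the sampler directly.
getADbackend(ctx::DynamicPPL.SamplingContext{<:Any,<:Sampler}) =
    getADbackend(ctx.sampler.alg)
getADbackend(::DynamicPPL.SamplingContext{<:Any,DynamicPPL.SampleFromPrior}) =
    ADBackend()()
```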
Yeah I can try doing that! I was worried this would break too many tests and discourage me from completing the PR, but now that the tests are passing I can see if it's worth it.
Pull Request Test Coverage Report for Build 3935194490 (Coveralls)
Codecov Report — Base: 0.00% // Head: 0.00% // No change to project coverage 👍

```
@@           Coverage Diff            @@
##           master    #1936   +/-   ##
=======================================
  Coverage    0.00%    0.00%
=======================================
  Files          21       21
  Lines        1432     1422      -10
=======================================
+ Misses       1432     1422      -10
```

View full report at Codecov.
Co-authored-by: David Widmann <[email protected]>
yebai left a comment
Good idea to phase out filtering `VarInfo` variables by sampler; the new context mechanism is strictly better, I think.
```julia
getADbackend(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = ADBackend()()
getADbackend(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.childcontext(ctx))

function LogDensityProblemsAD.ADgradient(ℓ::Turing.LogDensityFunction)
```
Most of the code in this AD file can safely be transferred to DynamicPPL too. One benefit is that users of DynamicPPL would get the gradient feature of `LogDensityFunction` without loading the Turing package.
Yup, though I think I'll leave it here for the time being just because I want to keep the PRs simple.
…ngLang/Turing.jl into torfjelde/less-sampler-more-context
Co-authored-by: Hong Ge <[email protected]>
```julia
# Make a new transition.
densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
densitymodel = AMH.DensityModel(
```
Isn't it easier to use a LogDensityModel? Then you don't need Fix1.
Oh true!
Wait, does that work though? AFAIK we don't support wrapping something that implements LogDensityModel in AMH.DensityModel?
We don't want to wrap it in a `DensityModel`, do we? `step` seems to support `LogDensityModel`s: https://github.com/TuringLang/AdvancedMH.jl/blob/1638f068261f936141404e232dd7f099b2cdde95/src/mh-core.jl#L74-L113
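For reference, a sketch of what that suggestion would look like (the follow-up comment in this thread reports it hit a dispatch issue in AdvancedMH at the time, so treat this as an illustration only):

```julia
# Hypothetical: wrap the LogDensityFunction in AbstractMCMC.LogDensityModel
# instead of AMH.DensityModel, so no Fix1 adapter is needed.
densitymodel = AbstractMCMC.LogDensityModel(
    Turing.LogDensityFunction(model, vi, DynamicPPL.DefaultContext())
)
```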
This didn't work, unfortunately. It seems we're hitting https://github.com/TuringLang/AdvancedMH.jl/blob/1638f068261f936141404e232dd7f099b2cdde95/src/AdvancedMH.jl#L69 instead of https://github.com/TuringLang/AdvancedMH.jl/blob/1638f068261f936141404e232dd7f099b2cdde95/src/AdvancedMH.jl#L72. Also, do you know the reasoning behind defining one method for both `DensityOrLogDensityModel` and `LogDensityModel`? Why isn't it `DensityModel` and `LogDensityModel` separately?
EDIT: Looking at git-blame, it seems like I tried to do a replace and was a bit too aggressive 😅
```julia
import AdvancedVI
using DynamicPPL: DynamicPPL, LogDensityFunction
import DynamicPPL: getspace, NoDist, NamedDist
import LogDensityProblems
```
Maybe this can be removed now?
```julia
import LogDensityProblems
```
Still need it for all the optimization-related stuff.
This reverts commit 596b0b2.
Long-term we want to move away from using `DynamicPPL.Sampler` to decide which variables are being sampled, etc. This PR reduces the explicit accessing and some usage of `DynamicPPL.Sampler` as one step towards this.

In addition, due to how `DynamicPPL.AbstractSampler` is handled differently for `SimpleVarInfo` and `VarInfo`, the current design of `LogDensityFunction` (which requires a sampler) means that we cannot use `SimpleVarInfo` in this regard (`SimpleVarInfo` + `AbstractSampler` results in always sampling variables, hence making it useless for computing gradients, etc.).