# Reduce usage of sampler #1936
```diff
@@ -77,13 +77,18 @@ Find the autodifferentiation backend of the algorithm `alg`.
 """
 getADbackend(spl::Sampler) = getADbackend(spl.alg)
 getADbackend(::SampleFromPrior) = ADBackend()()
```
> **Member** (on lines 78–79): Are these still needed? Or could we remove them and specialize …
>
> **Member (author)**: Yeah, I can try doing that! I was worried this would break too many tests and discourage me from completing the PR, but now that the tests are passing I can see if it's worth it.
```diff
+
+getADbackend(ctx::DynamicPPL.SamplingContext) = getADbackend(ctx.sampler)
+getADbackend(ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.NodeTrait(ctx), ctx)
+getADbackend(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = ADBackend()()
+getADbackend(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.childcontext(ctx))

 function LogDensityProblemsAD.ADgradient(ℓ::Turing.LogDensityFunction)
```
> **Member**: Most of the code in this …
>
> **Member (author)**: Yup, though I think I'll leave it here for the time being, just because I want to keep the PRs simple.
```diff
-    return LogDensityProblemsAD.ADgradient(getADbackend(ℓ.sampler), ℓ)
+    return LogDensityProblemsAD.ADgradient(getADbackend(ℓ.context), ℓ)
 end

 function LogDensityProblemsAD.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensityFunction)
-    θ = ℓ.varinfo[ℓ.sampler]
+    θ = DynamicPPL.getparams(ℓ)
     f = Base.Fix1(LogDensityProblems.logdensity, ℓ)

     # Define configuration for ForwardDiff.
```
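The added `getADbackend` methods use DynamicPPL's `NodeTrait` pattern: a `SamplingContext` answers directly from its sampler, any other context dispatches on its leaf/parent trait, and parents recurse into their child context. The following self-contained sketch illustrates that dispatch pattern; `MiniLeaf`, `MiniSampling`, `MiniWrapper`, and `backend_of` are invented stand-ins for illustration only and are not part of DynamicPPL or Turing.

```julia
# Hypothetical mini version of the leaf/parent trait recursion used above.
abstract type MiniContext end

struct MiniLeaf <: MiniContext end            # plays the role of a leaf context
struct MiniSampling{S} <: MiniContext         # plays the role of SamplingContext
    sampler::S
    child::MiniContext
end
struct MiniWrapper <: MiniContext             # some other parent context
    child::MiniContext
end

struct IsLeaf end
struct IsParent end

nodetrait(::MiniLeaf) = IsLeaf()
nodetrait(::MiniSampling) = IsParent()
nodetrait(::MiniWrapper) = IsParent()

childcontext(ctx::MiniSampling) = ctx.child
childcontext(ctx::MiniWrapper) = ctx.child

# Mirrors the PR: a sampling context answers directly from its sampler;
# any other context defers to its trait, and parents recurse into children.
backend_of(ctx::MiniSampling) = ctx.sampler
backend_of(ctx::MiniContext) = backend_of(nodetrait(ctx), ctx)
backend_of(::IsLeaf, ::MiniContext) = :default_backend
backend_of(::IsParent, ctx::MiniContext) = backend_of(childcontext(ctx))

backend_of(MiniWrapper(MiniSampling(:forwarddiff, MiniLeaf())))  # => :forwarddiff
backend_of(MiniWrapper(MiniLeaf()))                              # => :default_backend
```

The payoff of the trait indirection is that new context types only need a `nodetrait`/`childcontext` definition; the backend lookup itself never has to be extended.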
```diff
@@ -246,11 +246,11 @@ A log density function for the MH sampler.

 This variant uses the `set_namedtuple!` function to update the `VarInfo`.
 """
-const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}
+const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.SamplingContext{<:S}}

-function (f::MHLogDensityFunction)(x::NamedTuple)
+function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple)
     # TODO: Make this work with immutable `f.varinfo` too.
-    sampler = f.sampler
+    sampler = DynamicPPL.getsampler(f)
     vi = f.varinfo

     x_old, lj_old = vi[sampler], getlogp(vi)
```
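The change from a callable struct to a `LogDensityProblems.logdensity` method moves `MHLogDensityFunction` onto the standard LogDensityProblems.jl interface. A minimal sketch of that interface, using an invented example type `StdNormalTarget` (not part of Turing), assuming the documented three-method contract of LogDensityProblems.jl:

```julia
using LogDensityProblems

# Hypothetical target: an isotropic standard normal in `dim` dimensions.
struct StdNormalTarget
    dim::Int
end

# Instead of making the struct callable, overload the interface function.
function LogDensityProblems.logdensity(ℓ::StdNormalTarget, x::AbstractVector)
    return -0.5 * sum(abs2, x) - 0.5 * ℓ.dim * log(2π)
end

LogDensityProblems.dimension(ℓ::StdNormalTarget) = ℓ.dim

function LogDensityProblems.capabilities(::Type{StdNormalTarget})
    return LogDensityProblems.LogDensityOrder{0}()  # plain evaluation, no gradients
end

ℓ = StdNormalTarget(2)
LogDensityProblems.logdensity(ℓ, [0.0, 0.0])  # ≈ -log(2π)
```

Any sampler that consumes the LogDensityProblems interface can then use the type without Turing-specific glue, which is the point of the rewrite above.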
```diff
@@ -374,7 +374,9 @@ function propose!!(
     prev_trans = AMH.Transition(vt, getlogp(vi))

     # Make a new transition.
-    densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
+    densitymodel = AMH.DensityModel(
+        Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl)))
+    )
     trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)

     # TODO: Make this compatible with immutable `VarInfo`.
```

> **Member**: Isn't it easier to use a …
>
> **Member (author)**: Oh true!
>
> **Member (author)**: Wait, does that work though? AFAIK we don't support wrapping something that implements …
>
> **Member**: We don't want to wrap it in a …
>
> **Member (author)**: This didn't work, unfortunately. It seems we're hitting https://github.com/TuringLang/AdvancedMH.jl/blob/1638f068261f936141404e232dd7f099b2cdde95/src/AdvancedMH.jl#L69 instead of https://github.com/TuringLang/AdvancedMH.jl/blob/1638f068261f936141404e232dd7f099b2cdde95/src/AdvancedMH.jl#L72. Also, do you know the reasoning behind defining one for both …? EDIT: Looking at git-blame, it seems I tried to do a replace and was a bit too aggressive 😅
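The `Base.Fix1` in the new code bridges an interface mismatch: `AMH.DensityModel` expects a one-argument callable `x -> log density`, while the log density function now exposes its density through the two-argument `LogDensityProblems.logdensity(ℓ, x)`. `Base.Fix1(f, a)` (plain Julia `Base`) pins the first argument of `f`, producing exactly such a one-argument callable. Illustrated with an invented stand-in density rather than the real Turing types:

```julia
# Stand-in for LogDensityProblems.logdensity(ℓ, x): two-argument log density.
logdensity(params, x) = -0.5 * sum(abs2, x .- params.mean)

params = (mean = 1.0,)

# Fix the first argument; `f(x)` is now equivalent to `logdensity(params, x)`.
f = Base.Fix1(logdensity, params)

f([1.0, 1.0])  # => 0.0
f([2.0, 0.0])  # => -1.0
```

Compared with an anonymous closure `x -> logdensity(params, x)`, `Base.Fix1` produces a concretely typed callable that is easier for dispatch and specialization, which is why it is the idiomatic choice here.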
```diff
@@ -400,7 +402,9 @@ function propose!!(
     prev_trans = AMH.Transition(vals, getlogp(vi))

     # Make a new transition.
-    densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
+    densitymodel = AMH.DensityModel(
+        Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl)))
+    )
     trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)

     return setlogp!!(DynamicPPL.unflatten(vi, spl, trans.params), trans.lp)
```
> **Member**: Maybe this can be removed now?
>
> **Member (author)**: Still need it for all the optimization-related stuff.