|
2 | 2 | ### Particle Filtering and Particle MCMC Samplers. |
3 | 3 | ### |
4 | 4 |
|
| 5 | +### AdvancedPS models and interface |
| 6 | + |
| 7 | +struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: |
| 8 | + AdvancedPS.AbstractGenericModel |
| 9 | + model::M |
| 10 | + sampler::S |
| 11 | + varinfo::V |
| 12 | + evaluator::E |
| 13 | +end |
| 14 | + |
| 15 | +function TracedModel( |
| 16 | + model::Model, |
| 17 | + sampler::AbstractSampler, |
| 18 | + varinfo::AbstractVarInfo, |
| 19 | + rng::Random.AbstractRNG, |
| 20 | +) |
| 21 | + context = SamplingContext(rng, sampler, DefaultContext()) |
| 22 | + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) |
| 23 | + if kwargs !== nothing && !isempty(kwargs) |
| 24 | + error( |
| 25 | + "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.", |
| 26 | + ) |
| 27 | + end |
| 28 | + return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( |
| 29 | + model, sampler, varinfo, (model.f, args...) |
| 30 | + ) |
| 31 | +end |
| 32 | + |
| 33 | +function AdvancedPS.advance!( |
| 34 | + trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false |
| 35 | +) |
| 36 | + # Make sure we load/reset the rng in the new replaying mechanism |
| 37 | + DynamicPPL.increment_num_produce!(trace.model.f.varinfo) |
| 38 | + isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) |
| 39 | + score = consume(trace.model.ctask) |
| 40 | + if score === nothing |
| 41 | + return nothing |
| 42 | + else |
| 43 | + return score + DynamicPPL.getlogp(trace.model.f.varinfo) |
| 44 | + end |
| 45 | +end |
| 46 | + |
| 47 | +function AdvancedPS.delete_retained!(trace::TracedModel) |
| 48 | + DynamicPPL.set_retained_vns_del!(trace.varinfo) |
| 49 | + return trace |
| 50 | +end |
| 51 | + |
| 52 | +function AdvancedPS.reset_model(trace::TracedModel) |
| 53 | + DynamicPPL.reset_num_produce!(trace.varinfo) |
| 54 | + return trace |
| 55 | +end |
| 56 | + |
| 57 | +function AdvancedPS.reset_logprob!(trace::TracedModel) |
| 58 | + DynamicPPL.resetlogp!!(trace.model.varinfo) |
| 59 | + return trace |
| 60 | +end |
| 61 | + |
| 62 | +function AdvancedPS.update_rng!( |
| 63 | + trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}} |
| 64 | +) |
| 65 | + # Extract the `args`. |
| 66 | + args = trace.model.ctask.args |
| 67 | + # From `args`, extract the `SamplingContext`, which contains the RNG. |
| 68 | + sampling_context = args[3] |
| 69 | + rng = sampling_context.rng |
| 70 | + trace.rng = rng |
| 71 | + return trace |
| 72 | +end |
| 73 | + |
| 74 | +function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ? |
| 75 | + return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...) |
| 76 | +end |
| 77 | + |
5 | 78 | #### |
6 | 79 | #### Generic Sequential Monte Carlo sampler. |
7 | 80 | #### |
@@ -408,7 +481,7 @@ function AdvancedPS.Trace( |
408 | 481 | newvarinfo = deepcopy(varinfo) |
409 | 482 | DynamicPPL.reset_num_produce!(newvarinfo) |
410 | 483 |
|
411 | | - tmodel = Turing.Essential.TracedModel(model, sampler, newvarinfo, rng) |
| 484 | + tmodel = TracedModel(model, sampler, newvarinfo, rng) |
412 | 485 | newtrace = AdvancedPS.Trace(tmodel, rng) |
413 | 486 | AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace) |
414 | 487 | return newtrace |
|
0 commit comments