|
1 | | -struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model} |
| 1 | +struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} |
2 | 2 | model::M |
3 | 3 | sampler::S |
4 | 4 | varinfo::V |
| 5 | + evaluator::E |
5 | 6 | end |
6 | 7 |
|
7 | | -# needed? |
8 | | -function TracedModel{SampleFromPrior}( |
| 8 | +function TracedModel( |
9 | 9 | model::Model, |
10 | 10 | sampler::AbstractSampler, |
11 | 11 | varinfo::AbstractVarInfo, |
12 | | -) |
13 | | - return TracedModel(model, SampleFromPrior(), varinfo) |
| 12 | +) |
| 13 | + # evaluate!!(m.model, varinfo, SamplingContext(Random.AbstractRNG, m.sampler, DefaultContext())) |
| 14 | + context = SamplingContext(DynamicPPL.Random.GLOBAL_RNG, sampler, DefaultContext()) |
| 15 | + evaluator = _get_evaluator(model, varinfo, context) |
| 16 | + return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(model, sampler, varinfo, evaluator) |
14 | 17 | end |
15 | 18 |
|
16 | | -(f::TracedModel)() = f.model(f.varinfo, f.sampler) |
| 19 | +# Smiliar to `evaluate!!` except that we return the evaluator signature without excutation. |
| 20 | +# TODO: maybe move to DynamicPPL |
| 21 | +@generated function _get_evaluator( |
| 22 | + model::Model{_F,argnames}, varinfo, context |
| 23 | +) where {_F,argnames} |
| 24 | + unwrap_args = [ |
| 25 | + :($DynamicPPL.matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames |
| 26 | + ] |
| 27 | + # We want to give `context` precedence over `model.context` while also |
| 28 | + # preserving the leaf context of `context`. We can do this by |
| 29 | + # 1. Set the leaf context of `model.context` to `leafcontext(context)`. |
| 30 | + # 2. Set leaf context of `context` to the context resulting from (1). |
| 31 | + # The result is: |
| 32 | + # `context` -> `childcontext(context)` -> ... -> `model.context` |
| 33 | + # -> `childcontext(model.context)` -> ... -> `leafcontext(context)` |
| 34 | + return quote |
| 35 | + context_new = DynamicPPL.setleafcontext( |
| 36 | + context, DynamicPPL.setleafcontext(model.context, DynamicPPL.leafcontext(context)) |
| 37 | + ) |
| 38 | + (model.f, model, DynamicPPL.resetlogp!!(varinfo), context_new, $(unwrap_args...)) |
| 39 | + end |
| 40 | +end |
17 | 41 |
|
18 | 42 | function Base.copy(trace::AdvancedPS.Trace{<:TracedModel}) |
19 | 43 | f = trace.f |
@@ -46,4 +70,3 @@ function AdvancedPS.reset_logprob!(f::TracedModel) |
46 | 70 | DynamicPPL.resetlogp!!(f.varinfo) |
47 | 71 | return |
48 | 72 | end |
49 | | - |
|
0 commit comments