1+ """
2+ TrackedValue{T}
3+
4+ A struct that wraps something on the right-hand side of `:=`. This is needed
5+ because the DynamicPPL compiler actually converts `lhs := rhs` to `lhs ~
6+ TrackedValue(rhs)` (so that we can hit the `tilde_assume` method below). Having
7+ the rhs wrapped in a TrackedValue makes sure that the logpdf of the rhs is not
8+ computed (as it wouldn't make sense).
9+ """
110struct TrackedValue{T}
211 value:: T
312end
@@ -24,17 +33,27 @@ struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
2433 values:: OrderedDict
2534 " whether to extract variables on the LHS of :="
2635 include_colon_eq:: Bool
36+ " varnames to be tracked; `nothing` means track all varnames"
37+ tracked_varnames:: Union{Nothing,Array{<:VarName}}
2738 " child context"
2839 context:: C
2940end
30- function ValuesAsInModelContext (include_colon_eq, context:: AbstractContext )
31- return ValuesAsInModelContext (OrderedDict (), include_colon_eq, context)
41+ function ValuesAsInModelContext (
42+ include_colon_eq:: Bool ,
43+ tracked_varnames:: Union{Nothing,Array{<:VarName}} ,
44+ context:: AbstractContext ,
45+ )
46+ return ValuesAsInModelContext (
47+ OrderedDict (), include_colon_eq, tracked_varnames, context
48+ )
3249end
3350
3451NodeTrait (:: ValuesAsInModelContext ) = IsParent ()
3552childcontext (context:: ValuesAsInModelContext ) = context. context
3653function setchildcontext (context:: ValuesAsInModelContext , child)
37- return ValuesAsInModelContext (context. values, context. include_colon_eq, child)
54+ return ValuesAsInModelContext (
55+ context. values, context. include_colon_eq, context. tracked_varnames, child
56+ )
3857end
3958
4059is_extracting_values (context:: ValuesAsInModelContext ) = context. include_colon_eq
6382
6483# `tilde_asssume`
6584function tilde_assume (context:: ValuesAsInModelContext , right, vn, vi)
66- if is_tracked_value (right)
85+ is_tracked_value_right = is_tracked_value (right)
86+ if is_tracked_value_right
6787 value = right. value
6888 logp = zero (getlogp (vi))
6989 else
7090 value, logp, vi = tilde_assume (childcontext (context), right, vn, vi)
7191 end
7292 # Save the value.
73- push! (context, vn, value)
74- # Save the value.
93+ if is_tracked_value_right ||
94+ isnothing (context. tracked_varnames) ||
95+ any (tracked_vn -> subsumes (tracked_vn, vn), context. tracked_varnames)
96+ push! (context, vn, value)
97+ end
7598 # Pass on.
7699 return value, logp, vi
77100end
78101function tilde_assume (
79102 rng:: Random.AbstractRNG , context:: ValuesAsInModelContext , sampler, right, vn, vi
80103)
81- if is_tracked_value (right)
104+ is_tracked_value_right = is_tracked_value (right)
105+ if is_tracked_value_right
82106 value = right. value
83107 logp = zero (getlogp (vi))
84108 else
85109 value, logp, vi = tilde_assume (rng, childcontext (context), sampler, right, vn, vi)
86110 end
87111 # Save the value.
88- push! (context, vn, value)
112+ if is_tracked_value_right ||
113+ isnothing (context. tracked_varnames) ||
114+ any (tracked_vn -> subsumes (tracked_vn, vn), context. tracked_varnames)
115+ push! (context, vn, value)
116+ end
89117 # Pass on.
90118 return value, logp, vi
91119end
@@ -167,9 +195,39 @@ function values_as_in_model(
167195 model:: Model ,
168196 include_colon_eq:: Bool ,
169197 varinfo:: AbstractVarInfo ,
198+ tracked_varnames= tracked_varnames (model),
170199 context:: AbstractContext = DefaultContext (),
171200)
172- context = ValuesAsInModelContext (include_colon_eq, context)
201+ tracked_varnames = isnothing (tracked_varnames) ? nothing : collect (tracked_varnames)
202+ context = ValuesAsInModelContext (include_colon_eq, tracked_varnames, context)
173203 evaluate!! (model, varinfo, context)
174204 return context. values
175205end
206+
207+ """
208+ tracked_varnames(model::Model)
209+
210+ Returns a set of `VarName`s that the model should track.
211+
212+ By default, this returns `nothing`, which means that all `VarName`s should be
213+ tracked.
214+
215+ If you want to track only a subset of `VarName`s, you can override this method
216+ in your model definition:
217+
218+ ```julia
219+ @model function mymodel()
220+ x ~ Normal()
221+ y ~ Normal(x, 1)
222+ end
223+
224+ DynamicPPL.tracked_varnames(::Model{typeof(mymodel)}) = [@varname(y)]
225+ ```
226+
227+ Then, when you sample from `mymodel()`, only the value of `y` will be tracked
228+ (and not `x`).
229+
230+ Note that quantities on the left-hand side of `:=` are always tracked, and will
231+ ignore the varnames specified in this method.
232+ """
233+ tracked_varnames (:: Model ) = nothing
0 commit comments