@@ -17,8 +17,11 @@ The major differences between this and `TypedVarInfo` are:
1717 b) the values have been specified with the corret shapes.
1818
1919# Examples
20- ```jldoctest; setup=:(using Distributions, Random)
20+ ```jldoctest; setup=:(using Distributions)
21+ julia> using StableRNGs
22+
2123julia> @model function demo()
24+ m ~ Normal()
2225 x = Vector{Float64}(undef, 2)
2326 for i in eachindex(x)
2427 x[i] ~ Normal()
@@ -29,59 +32,46 @@ demo (generic function with 1 method)
2932
3033julia> m = demo();
3134
32- julia> ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext());
33-
34- julia> # Notice how the resulting `vi` has keys `(var"x[1]", var"x[2]")`
35- # and thus accessing these values will be type-unstable and slower.
36- _, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi
37- SimpleVarInfo{NamedTuple{(Symbol("x[1]"), Symbol("x[2]")), Tuple{Float64, Float64}}, Float64}((x[1] = 0.14447203090358265, x[2] = 0.21780448216717593), -1.8720325464921044)
38-
39- julia> # (×) SLOW!!!
40- DynamicPPL.getval(vi, @varname(x[1]))
41- 0.14447203090358265
42-
43- julia> # In addtion, we can only access varnames as they appear in the model!
44- DynamicPPL.getval(vi, @varname(x))
45- ERROR: type NamedTuple has no field x
46- [...]
35+ julia> rng = StableRNG(42);
4736
48- julia> julia> DynamicPPL.getval(vi, @varname(x[1:2]))
49- ERROR: type NamedTuple has no field x[1:2]
50- [...]
37+ julia> ### Sampling ###
38+ ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext());
5139
52- julia> # In contrast, if we provide the container for `x`, the `vi` now only
53- # has the key `x` and we access parts of it using indices.
40+ julia> # In the `NamedTuple` version we need to provide the place-holder values for
41+ # the variablse which are using "containers", e.g. `Array`.
42+ # In this case, this means that we need to specify `x` but not `m`.
5443 _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi
55- SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64}((x = [-0.6538238172778861, 0.10742338922309654], ), -2.0573897507053474 )
44+ SimpleVarInfo{NamedTuple{(:x, :m ), Tuple{Vector{Float64}, Float64 }}, Float64}((x = [1.6642061055583879, 1.796319600944139], m = -0.16796295277202952 ), -5.769094411622931 )
5645
5746julia> # (✓) Vroom, vroom! FAST!!!
5847 DynamicPPL.getval(vi, @varname(x[1]))
59- -0.6538238172778861
48+ 1.6642061055583879
6049
6150julia> # We can also access arbitrary varnames pointing to `x`, e.g.
6251 DynamicPPL.getval(vi, @varname(x))
63522-element Vector{Float64}:
64- -0.6538238172778861
65- 0.10742338922309654
53+ 1.6642061055583879
54+ 1.796319600944139
6655
6756julia> DynamicPPL.getval(vi, @varname(x[1:2]))
68572-element view(::Vector{Float64}, 1:2) with eltype Float64:
69- -0.6538238172778861
70- 0.10742338922309654
58+ 1.6642061055583879
59+ 1.796319600944139
60+
61+ julia> # (×) If we don't provide the container...
62+ _, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi
63+ ERROR: type NamedTuple has no field x
64+ [...]
7165
72- julia> # The better way to handle sampling of variables involving indexing
73- # if one does not know the varnames, is to use a `Dict` as the container instead.
74- # Notice that here the keys are the same as for the `SimpleVarInfo()` scenario, i.e.
75- # how they appear in the model.
66+ julia> # If one does not know the varnames, we can use a `Dict` instead.
7667 _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); vi
77- SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.1292246244328437 , x[2] => -1.382335836121636 ), -3.4308773745351453 )
68+ SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.192696983568277 , x[2] => 0.4914514300738121, m => 0.25572200616753643 ), -3.6215377732004237 )
7869
7970julia> # (✓) Sort of fast, but only possible at runtime.
8071 DynamicPPL.getval(vi, @varname(x[1]))
81- 1.1292246244328437
72+ 1.192696983568277
8273
83- julia> # And as in the `SimpleVarInfo()` case, we cannot access varnames that does
84- # not directly appear in the model.
74+ julia> # In addtion, we can only access varnames as they appear in the model!
8575 DynamicPPL.getval(vi, @varname(x))
8676ERROR: KeyError: key x not found
8777[...]
@@ -136,24 +126,15 @@ end
136126function _getvalue (nt:: NamedTuple , :: Val{sym} , inds= ()) where {sym}
137127 # Use `getproperty` instead of `getfield`
138128 value = getproperty (nt, sym)
129+ # Note that this will return a `view`, even if the resulting value is 0-dim.
130+ # This makes it possible to call `setindex!` on the result later to update
131+ # in place even in the case where are retrieving a single element, e.g. `x[1]`.
139132 return _getindex (value, inds)
140133end
141134
142135# `NamedTuple`
143- function getval (vi:: SimpleVarInfo , vn:: VarName{sym} ) where {sym}
144- # If `sym` is found in `vi.θ` we assume it will be of correct
145- # shape to support `getindex` for `vn.indexing`.
146- # If `sym` is NOT found in `vi.θ`, we try `Symbol(vn)`.
147- # This means that we support both the following cases:
148- # 1. `x[1]` has been provided by the user and can be assumed to be
149- # of shape that allows us to call `_getvalue` on it.
150- # 2. `x[1]` was not provided by the user, e.g. possibly obtained by
151- # sampling with a `SimpleVarInfo` which then produced the key `var"x[1]"`.
152- return if haskey (vi. θ, sym)
153- maybe_unwrap_view (_getvalue (vi. θ, Val {sym} (), vn. indexing))
154- else
155- getproperty (vi. θ, Symbol (vn))
156- end
136+ function getval (vi:: SimpleVarInfo{<:NamedTuple} , vn:: VarName{sym} ) where {sym}
137+ return maybe_unwrap_view (_getvalue (vi. θ, Val {sym} (), vn. indexing))
157138end
158139
159140# `Dict`
@@ -204,14 +185,12 @@ function push!!(
204185 dist:: Distribution ,
205186 gidset:: Set{Selector} ,
206187) where {sym}
207- # If the key is already there, we try to update in place.
208- return if haskey (vi. θ, sym)
209- current = _getvalue (vi. θ, Val {sym} (), vn. indexing)
210- current .= value
211- vi
212- else
213- @set vi. θ = merge (vi. θ, NamedTuple {(Symbol(vn),)} ((value,)))
214- end
188+ # We update in place.
189+ # We need a view into the array, hence we call `_getvalue` directly
190+ # rather than `getval`.
191+ current = _getvalue (vi. θ, Val {sym} (), vn. indexing)
192+ current .= value
193+ return vi
215194end
216195
217196# `Dict`
0 commit comments