Skip to content

Commit a72594f

Browse files
committed
dont allow sampling with indexing when using SimpleVarInfo with NamedTuple unless shapes are specified
1 parent d29dd8f commit a72594f

File tree

1 file changed

+36
-57
lines changed

1 file changed

+36
-57
lines changed

src/simple_varinfo.jl

Lines changed: 36 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ The major differences between this and `TypedVarInfo` are:
1717
b) the values have been specified with the corret shapes.
1818
1919
# Examples
20-
```jldoctest; setup=:(using Distributions, Random)
20+
```jldoctest; setup=:(using Distributions)
21+
julia> using StableRNGs
22+
2123
julia> @model function demo()
24+
m ~ Normal()
2225
x = Vector{Float64}(undef, 2)
2326
for i in eachindex(x)
2427
x[i] ~ Normal()
@@ -29,59 +32,46 @@ demo (generic function with 1 method)
2932
3033
julia> m = demo();
3134
32-
julia> ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext());
33-
34-
julia> # Notice how the resulting `vi` has keys `(var"x[1]", var"x[2]")`
35-
# and thus accessing these values will be type-unstable and slower.
36-
_, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi
37-
SimpleVarInfo{NamedTuple{(Symbol("x[1]"), Symbol("x[2]")), Tuple{Float64, Float64}}, Float64}((x[1] = 0.14447203090358265, x[2] = 0.21780448216717593), -1.8720325464921044)
38-
39-
julia> # (×) SLOW!!!
40-
DynamicPPL.getval(vi, @varname(x[1]))
41-
0.14447203090358265
42-
43-
julia> # In addtion, we can only access varnames as they appear in the model!
44-
DynamicPPL.getval(vi, @varname(x))
45-
ERROR: type NamedTuple has no field x
46-
[...]
35+
julia> rng = StableRNG(42);
4736
48-
julia> julia> DynamicPPL.getval(vi, @varname(x[1:2]))
49-
ERROR: type NamedTuple has no field x[1:2]
50-
[...]
37+
julia> ### Sampling ###
38+
ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext());
5139
52-
julia> # In contrast, if we provide the container for `x`, the `vi` now only
53-
# has the key `x` and we access parts of it using indices.
40+
julia> # In the `NamedTuple` version we need to provide the place-holder values for
41+
# the variablse which are using "containers", e.g. `Array`.
42+
# In this case, this means that we need to specify `x` but not `m`.
5443
_, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi
55-
SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64}((x = [-0.6538238172778861, 0.10742338922309654],), -2.0573897507053474)
44+
SimpleVarInfo{NamedTuple{(:x, :m), Tuple{Vector{Float64}, Float64}}, Float64}((x = [1.6642061055583879, 1.796319600944139], m = -0.16796295277202952), -5.769094411622931)
5645
5746
julia> # (✓) Vroom, vroom! FAST!!!
5847
DynamicPPL.getval(vi, @varname(x[1]))
59-
-0.6538238172778861
48+
1.6642061055583879
6049
6150
julia> # We can also access arbitrary varnames pointing to `x`, e.g.
6251
DynamicPPL.getval(vi, @varname(x))
6352
2-element Vector{Float64}:
64-
-0.6538238172778861
65-
0.10742338922309654
53+
1.6642061055583879
54+
1.796319600944139
6655
6756
julia> DynamicPPL.getval(vi, @varname(x[1:2]))
6857
2-element view(::Vector{Float64}, 1:2) with eltype Float64:
69-
-0.6538238172778861
70-
0.10742338922309654
58+
1.6642061055583879
59+
1.796319600944139
60+
61+
julia> # (×) If we don't provide the container...
62+
_, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi
63+
ERROR: type NamedTuple has no field x
64+
[...]
7165
72-
julia> # The better way to handle sampling of variables involving indexing
73-
# if one does not know the varnames, is to use a `Dict` as the container instead.
74-
# Notice that here the keys are the same as for the `SimpleVarInfo()` scenario, i.e.
75-
# how they appear in the model.
66+
julia> # If one does not know the varnames, we can use a `Dict` instead.
7667
_, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); vi
77-
SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.1292246244328437, x[2] => -1.382335836121636), -3.4308773745351453)
68+
SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.192696983568277, x[2] => 0.4914514300738121, m => 0.25572200616753643), -3.6215377732004237)
7869
7970
julia> # (✓) Sort of fast, but only possible at runtime.
8071
DynamicPPL.getval(vi, @varname(x[1]))
81-
1.1292246244328437
72+
1.192696983568277
8273
83-
julia> # And as in the `SimpleVarInfo()` case, we cannot access varnames that does
84-
# not directly appear in the model.
74+
julia> # In addtion, we can only access varnames as they appear in the model!
8575
DynamicPPL.getval(vi, @varname(x))
8676
ERROR: KeyError: key x not found
8777
[...]
@@ -136,24 +126,15 @@ end
136126
function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym}
137127
# Use `getproperty` instead of `getfield`
138128
value = getproperty(nt, sym)
129+
# Note that this will return a `view`, even if the resulting value is 0-dim.
130+
# This makes it possible to call `setindex!` on the result later to update
131+
# in place even in the case where are retrieving a single element, e.g. `x[1]`.
139132
return _getindex(value, inds)
140133
end
141134

142135
# `NamedTuple`
143-
function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym}
144-
# If `sym` is found in `vi.θ` we assume it will be of correct
145-
# shape to support `getindex` for `vn.indexing`.
146-
# If `sym` is NOT found in `vi.θ`, we try `Symbol(vn)`.
147-
# This means that we support both the following cases:
148-
# 1. `x[1]` has been provided by the user and can be assumed to be
149-
# of shape that allows us to call `_getvalue` on it.
150-
# 2. `x[1]` was not provided by the user, e.g. possibly obtained by
151-
# sampling with a `SimpleVarInfo` which then produced the key `var"x[1]"`.
152-
return if haskey(vi.θ, sym)
153-
maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing))
154-
else
155-
getproperty(vi.θ, Symbol(vn))
156-
end
136+
function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}) where {sym}
137+
return maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing))
157138
end
158139

159140
# `Dict`
@@ -204,14 +185,12 @@ function push!!(
204185
dist::Distribution,
205186
gidset::Set{Selector},
206187
) where {sym}
207-
# If the key is already there, we try to update in place.
208-
return if haskey(vi.θ, sym)
209-
current = _getvalue(vi.θ, Val{sym}(), vn.indexing)
210-
current .= value
211-
vi
212-
else
213-
@set vi.θ = merge(vi.θ, NamedTuple{(Symbol(vn),)}((value,)))
214-
end
188+
# We update in place.
189+
# We need a view into the array, hence we call `_getvalue` directly
190+
# rather than `getval`.
191+
current = _getvalue(vi.θ, Val{sym}(), vn.indexing)
192+
current .= value
193+
return vi
215194
end
216195

217196
# `Dict`

0 commit comments

Comments
 (0)