-
Notifications
You must be signed in to change notification settings - Fork 37
Closed
Description
While working on a PR for `generated_quantities`, I ran into the following:
julia> using DynamicPPL, Turing
julia> Turing.turnprogress(false);
[ Info: [Turing]: progress logging is disabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as false
julia> @model function demo_fails(xs, ::Type{TV} = Vector{Float64}) where {TV}
m = TV(undef, 2)
for i in 1:2
m[i] ~ Normal(0, 1)
end
for i in eachindex(xs)
xs[i] ~ Normal(m[1], 1.)
end
return (m, )
end;
julia> xs = randn(3);
julia> model_fails = demo_fails(xs);
julia> chain_fails = sample(model_fails, NUTS(0.65), 100);
┌ Info: Found initial step size
└ ϵ = 1.6
julia> var_info = VarInfo(model_fails);
julia> spl = SampleFromPrior();
julia> θ_0 = var_info[spl]
2-element Array{Float64,1}:
-0.34786841389137324
0.4703320351984347
julia> res_0 = model_fails(var_info, spl)
([-0.34786841389137324, 0.4703320351984347],)
julia> DynamicPPL.setval!(var_info, chain_fails, 1, 1);
julia> θ_1 = var_info[spl] # <= note that the value has NOT changed since `θ_0`!!!
2-element Array{Float64,1}:
-0.34786841389137324
0.4703320351984347
julia> res_1 = model_fails(var_info, spl) # <= has NOT changed since `res_0`!!!
([-0.34786841389137324, 0.4703320351984347],)
...In contrast, the following works just fine:
julia> @model function demo_works(xs)
m ~ MvNormal(2, 1.)
for i in eachindex(xs)
xs[i] ~ Normal(m[1], 1.)
end
return (m, )
end;
julia> model_works = demo_works(xs);
julia> chain_works = sample(model_works, NUTS(0.65), 100);
┌ Info: Found initial step size
└ ϵ = 1.6
julia> var_info = VarInfo(model_works);
julia> spl = SampleFromPrior();
julia> θ_0 = var_info[spl]
2-element Array{Float64,1}:
-0.38315867392034214
1.2157641175253535
julia> res_0 = model_works(var_info, spl)
([-0.38315867392034214, 1.2157641175253535],)
julia> DynamicPPL.setval!(var_info, chain_works, 1, 1);
julia> θ_1 = var_info[spl] # <= note that the value has changed since `θ_0`
2-element Array{Float64,1}:
-0.5500762603588643
-0.8341982280432694
julia> res_1 = model_works(var_info, spl) # <= has indeed changed since `res_0`
([-0.5500762603588643, -0.8341982280432694],)
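...So in the first example `setval!` fails: after calling it, both `θ_1` and `res_1` are identical to `θ_0` and `res_0`, whereas in the second example they are updated to the values from the chain. The only structural difference between the two models is that `demo_fails` declares `m` element-wise (`m[i] ~ Normal(0, 1)` for each `i`), while `demo_works` declares it as a single multivariate variable (`m ~ MvNormal(2, 1.)`).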
Metadata
Metadata
Assignees
Labels
No labels