
setval! is being weird #167

@torfjelde

So I'm currently making a PR for `generated_quantities` and ran into the following:

julia> using DynamicPPL, Turing

julia> Turing.turnprogress(false);
[ Info: [Turing]: progress logging is disabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as false

julia> @model function demo_fails(xs, ::Type{TV} = Vector{Float64}) where {TV}
           m = TV(undef, 2)
           for i in 1:2
               m[i] ~ Normal(0, 1)
           end

           for i in eachindex(xs)
               xs[i] ~ Normal(m[1], 1.)
           end

           return (m, )
       end;

julia> xs = randn(3);

julia> model_fails = demo_fails(xs);

julia> chain_fails = sample(model_fails, NUTS(0.65), 100);
┌ Info: Found initial step size
└   ϵ = 1.6

julia> var_info = VarInfo(model_fails);

julia> spl = SampleFromPrior();

julia> θ_0 = var_info[spl]
2-element Array{Float64,1}:
 -0.34786841389137324
  0.4703320351984347

julia> res_0 = model_fails(var_info, spl)
([-0.34786841389137324, 0.4703320351984347],)

julia> DynamicPPL.setval!(var_info, chain_fails, 1, 1);

julia> θ_1 = var_info[spl] # <= note that the value has NOT changed since `θ_0`
2-element Array{Float64,1}:
 -0.34786841389137324
  0.4703320351984347

julia> res_1 = model_fails(var_info, spl) # <= has NOT changed since `res_0`!!!
([-0.34786841389137324, 0.4703320351984347],)
 ...

In contrast, the following works just fine:

julia> @model function demo_works(xs)
           m ~ MvNormal(2, 1.)
           for i in eachindex(xs)
               xs[i] ~ Normal(m[1], 1.)
           end

           return (m, )
       end;

julia> model_works = demo_works(xs);

julia> chain_works = sample(model_works, NUTS(0.65), 100);
┌ Info: Found initial step size
└   ϵ = 1.6

julia> var_info = VarInfo(model_works);

julia> spl = SampleFromPrior();

julia> θ_0 = var_info[spl]
2-element Array{Float64,1}:
 -0.38315867392034214
  1.2157641175253535

julia> res_0 = model_works(var_info, spl)
([-0.38315867392034214, 1.2157641175253535],)

julia> DynamicPPL.setval!(var_info, chain_works, 1, 1);

julia> θ_1 = var_info[spl] # <= note that the value has changed since `θ_0`
2-element Array{Float64,1}:
 -0.5500762603588643
 -0.8341982280432694

julia> res_1 = model_works(var_info, spl) # <= has indeed changed since `res_0`
([-0.5500762603588643, -0.8341982280432694],)
 ...

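So in the first example, `setval!` silently fails: `var_info` still holds the values it had before the call, whereas in the second example it picks up the new values from the chain.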
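The report does not say why only the per-element model fails. One hypothesis worth checking, which nothing above confirms, is how `setval!` maps variable names to chain columns. `demo_works` has a single variable named `m`, while `demo_fails` has `m[1]` and `m[2]`. If that mapping builds a regular expression from the variable's name without escaping it, the brackets in `m[1]` become a character class, the name no longer matches the chain key `m[1]`, and the update would be skipped without any error. The sketch below illustrates only that failure mode, using Base Julia; `naive_match`, `quoted_match`, and the assumed chain keys are hypothetical and are not DynamicPPL's actual code.

```julia
# Hypothetical sketch: pick out the chain columns that belong to a variable by
# building a regex from the variable's name. Neither helper is DynamicPPL code.

# Name interpolated into the pattern as-is: regex metacharacters in `name`
# (such as `[` and `]`) are interpreted by the regex engine.
naive_match(name::AbstractString, key::AbstractString) =
    occursin(Regex("^" * name * "(\$|\\[)"), key)

# Name wrapped in PCRE's \Q...\E, so it is matched literally.
quoted_match(name::AbstractString, key::AbstractString) =
    occursin(Regex("^\\Q" * name * "\\E(\$|\\[)"), key)

# Assumption: both models store `m` in the chain under these keys.
chain_keys = ["m[1]", "m[2]"]

# demo_works-style name `m`: the pattern `^m($|\[)` finds both keys.
println(filter(k -> naive_match("m", k), chain_keys))     # ["m[1]", "m[2]"]

# demo_fails-style name `m[1]`: in `^m[1]($|\[)` the `[1]` is a character
# class matching the single character '1', so no key matches.
println(filter(k -> naive_match("m[1]", k), chain_keys))  # String[]

# Quoting the name restores the intended match.
println(filter(k -> quoted_match("m[1]", k), chain_keys)) # ["m[1]"]
```

If this hypothesis holds, quoting or escaping the name (or comparing names as plain strings) would be the kind of fix to look for.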
