@@ -130,6 +130,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
130130 test_base!! (SimpleVarInfo (Dict ()))
131131 test_base!! (SimpleVarInfo (DynamicPPL. VarNamedVector ()))
132132 end
133+
133134 @testset " flags" begin
134135 # Test flag setting:
135136 # is_flagged, set_flag!, unset_flag!
@@ -187,6 +188,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
187188 setgid! (vi, gid2, vn)
188189 @test meta. x. gids[meta. x. idcs[vn]] == Set ([gid1, gid2])
189190 end
191+
190192 @testset " setval! & setval_and_resample!" begin
191193 @model function testmodel (x)
192194 n = length (x)
@@ -339,6 +341,52 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
339341 @test vals_prev == vi. metadata. x. vals
340342 end
341343
344+ @testset " setval! on chain" begin
345+ # Define a helper function
346+ """
347+ test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
348+
349+ Test `setval!` on `model` and `chain`.
350+
351+ Worth noting that this only supports models containing symbols of the forms
352+ `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
353+ """
354+ function test_setval! (model, chain; sample_idx= 1 , chain_idx= 1 )
355+ var_info = VarInfo (model)
356+ spl = SampleFromPrior ()
357+ θ_old = var_info[spl]
358+ DynamicPPL. setval! (var_info, chain, sample_idx, chain_idx)
359+ θ_new = var_info[spl]
360+ @test θ_old != θ_new
361+ vals = DynamicPPL. values_as (var_info, OrderedDict)
362+ iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
363+ for (n, v) in mapreduce (collect, vcat, iters)
364+ n = string (n)
365+ if Symbol (n) ∉ keys (chain)
366+ # Assume it's a group
367+ chain_val = vec (
368+ MCMCChains. group (chain, Symbol (n)). value[sample_idx, :, chain_idx]
369+ )
370+ v_true = vec (v)
371+ else
372+ chain_val = chain[sample_idx, n, chain_idx]
373+ v_true = v
374+ end
375+
376+ @test v_true == chain_val
377+ end
378+ end
379+
380+ @testset " $model " for model in DynamicPPL. TestUtils. DEMO_MODELS
381+ chain = make_chain_from_prior (model, 10 )
382+ # A simple way of checking that the computation is determinstic: run twice and compare.
383+ res1 = generated_quantities (model, MCMCChains. get_sections (chain, :parameters ))
384+ res2 = generated_quantities (model, MCMCChains. get_sections (chain, :parameters ))
385+ @test all (res1 .== res2)
386+ test_setval! (model, MCMCChains. get_sections (chain, :parameters ))
387+ end
388+ end
389+
342390 @testset " istrans" begin
343391 @model demo_constrained () = x ~ truncated (Normal (), 0 , Inf )
344392 model = demo_constrained ()
0 commit comments