@@ -90,12 +90,12 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
9090 samples = (; samples_dict... )
9191 samples = modify_value_representation (samples) # `modify_value_representation` defined in test/test_util.jl
9292 @test logpriors[i] ≈
93- DynamicPPL. TestUtils. logprior_true (model, samples[:s ], samples[:m ])
93+ DynamicPPL. TestUtils. logprior_true (model, samples[:s ], samples[:m ])
9494 @test loglikelihoods[i] ≈ DynamicPPL. TestUtils. loglikelihood_true (
9595 model, samples[:s ], samples[:m ]
9696 )
9797 @test logjoints[i] ≈
98- DynamicPPL. TestUtils. logjoint_true (model, samples[:s ], samples[:m ])
98+ DynamicPPL. TestUtils. logjoint_true (model, samples[:s ], samples[:m ])
9999 end
100100 end
101101 end
@@ -283,10 +283,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
283283 # Ensure log-probability computations are implemented.
284284 @test logprior (model, x) ≈ DynamicPPL. TestUtils. logprior_true (model, x... )
285285 @test loglikelihood (model, x) ≈
286- DynamicPPL. TestUtils. loglikelihood_true (model, x... )
286+ DynamicPPL. TestUtils. loglikelihood_true (model, x... )
287287 @test logjoint (model, x) ≈ DynamicPPL. TestUtils. logjoint_true (model, x... )
288288 @test logjoint (model, x) !=
289- DynamicPPL. TestUtils. logjoint_true_with_logabsdet_jacobian (model, x... )
289+ DynamicPPL. TestUtils. logjoint_true_with_logabsdet_jacobian (model, x... )
290290 # Ensure `varnames` is implemented.
291291 vi = last (
292292 DynamicPPL. evaluate!! (
@@ -383,7 +383,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
383383 example_values = DynamicPPL. TestUtils. rand_prior_true (model)
384384 varinfos = DynamicPPL. TestUtils. setup_varinfos (model, example_values, vns)
385385 @testset " $(short_varinfo_name (varinfo)) " for varinfo in varinfos
386- realizations = values_as_in_model (model, varinfo)
386+ # We can set the include_colon_eq arg to false because none of
387+ # the demo models contain :=. The behaviour when
388+ # include_colon_eq is true is tested in test/compiler.jl
389+ realizations = values_as_in_model (model, false , varinfo)
387390 # Ensure that all variables are found.
388391 vns_found = collect (keys (realizations))
389392 @test vns ∩ vns_found == vns ∪ vns_found
@@ -432,72 +435,85 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
432435
433436 @testset " predict" begin
434437 @testset " with MCMCChains.Chains" begin
435- DynamicPPL. Random. seed! (100 )
436-
437438 @model function linear_reg (x, y, σ= 0.1 )
438439 β ~ Normal (0 , 1 )
439440 for i in eachindex (y)
440441 y[i] ~ Normal (β * x[i], σ)
441442 end
443+ # Insert a := block to test that it is not included in predictions
444+ σ2 := σ^ 2
442445 end
443446
444- @model function linear_reg_vec (x, y, σ= 0.1 )
445- β ~ Normal (0 , 1 )
446- return y ~ MvNormal (β .* x, σ^ 2 * I)
447- end
448-
447+ # Construct a chain with 'sampled values' of β
449448 ground_truth_β = 2
450449 β_chain = MCMCChains. Chains (rand (Normal (ground_truth_β, 0.002 ), 1000 ), [:β ])
451450
451+ # Generate predictions from that chain
452452 xs_test = [10 + 0.1 , 10 + 2 * 0.1 ]
453453 m_lin_reg_test = linear_reg (xs_test, fill (missing , length (xs_test)))
454454 predictions = DynamicPPL. predict (m_lin_reg_test, β_chain)
455455
456- ys_pred = vec (mean (Array (group (predictions, :y )); dims= 1 ))
457- @test ys_pred[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
458- @test ys_pred[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
459-
460- # Ensure that `rng` is respected
461- rng = MersenneTwister (42 )
462- predictions1 = DynamicPPL. predict (rng, m_lin_reg_test, β_chain[1 : 2 ])
463- predictions2 = DynamicPPL. predict (
464- MersenneTwister (42 ), m_lin_reg_test, β_chain[1 : 2 ]
465- )
466- @test all (Array (predictions1) .== Array (predictions2))
467-
468- # Predict on two last indices for vectorized
469- m_lin_reg_test = linear_reg_vec (xs_test, missing )
470- predictions_vec = DynamicPPL. predict (m_lin_reg_test, β_chain)
471- ys_pred_vec = vec (mean (Array (group (predictions_vec, :y )); dims= 1 ))
472-
473- @test ys_pred_vec[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
474- @test ys_pred_vec[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
456+ # Also test a vectorized model
457+ @model function linear_reg_vec (x, y, σ= 0.1 )
458+ β ~ Normal (0 , 1 )
459+ return y ~ MvNormal (β .* x, σ^ 2 * I)
460+ end
461+ m_lin_reg_test_vec = linear_reg_vec (xs_test, missing )
475462
476- # Multiple chains
477- multiple_β_chain = MCMCChains. Chains (
478- reshape (rand (Normal (ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ), [:β ]
479- )
480- m_lin_reg_test = linear_reg (xs_test, fill (missing , length (xs_test)))
481- predictions = DynamicPPL. predict (m_lin_reg_test, multiple_β_chain)
482- @test size (multiple_β_chain, 3 ) == size (predictions, 3 )
463+ @testset " variables in chain" begin
464+ # Note that this also checks that variables on the lhs of :=,
465+ # such as σ2, are not included in the resulting chain
466+ @test Set (keys (predictions)) == Set ([Symbol (" y[1]" ), Symbol (" y[2]" )])
467+ end
483468
484- for chain_idx in MCMCChains . chains (multiple_β_chain)
485- ys_pred = vec (mean (Array (group (predictions[:, :, chain_idx] , :y )); dims= 1 ))
469+ @testset " accuracy " begin
470+ ys_pred = vec (mean (Array (group (predictions, :y )); dims= 1 ))
486471 @test ys_pred[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
487472 @test ys_pred[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
488473 end
489474
490- # Predict on two last indices for vectorized
491- m_lin_reg_test = linear_reg_vec (xs_test, missing )
492- predictions_vec = DynamicPPL. predict (m_lin_reg_test, multiple_β_chain)
493-
494- for chain_idx in MCMCChains. chains (multiple_β_chain)
495- ys_pred_vec = vec (
496- mean (Array (group (predictions_vec[:, :, chain_idx], :y )); dims= 1 )
475+ @testset " ensure that rng is respected" begin
476+ rng = MersenneTwister (42 )
477+ predictions1 = DynamicPPL. predict (rng, m_lin_reg_test, β_chain[1 : 2 ])
478+ predictions2 = DynamicPPL. predict (
479+ MersenneTwister (42 ), m_lin_reg_test, β_chain[1 : 2 ]
497480 )
481+ @test all (Array (predictions1) .== Array (predictions2))
482+ end
483+
484+ @testset " accuracy on vectorized model" begin
485+ predictions_vec = DynamicPPL. predict (m_lin_reg_test_vec, β_chain)
486+ ys_pred_vec = vec (mean (Array (group (predictions_vec, :y )); dims= 1 ))
487+
498488 @test ys_pred_vec[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
499489 @test ys_pred_vec[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
500490 end
491+
492+ @testset " prediction from multiple chains" begin
493+ # Normal linreg model
494+ multiple_β_chain = MCMCChains. Chains (
495+ reshape (rand (Normal (ground_truth_β, 0.002 ), 1000 , 2 ), 1000 , 1 , 2 ), [:β ]
496+ )
497+ predictions = DynamicPPL. predict (m_lin_reg_test, multiple_β_chain)
498+ @test size (multiple_β_chain, 3 ) == size (predictions, 3 )
499+
500+ for chain_idx in MCMCChains. chains (multiple_β_chain)
501+ ys_pred = vec (mean (Array (group (predictions[:, :, chain_idx], :y )); dims= 1 ))
502+ @test ys_pred[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
503+ @test ys_pred[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
504+ end
505+
506+ # Vectorized linreg model
507+ predictions_vec = DynamicPPL. predict (m_lin_reg_test_vec, multiple_β_chain)
508+
509+ for chain_idx in MCMCChains. chains (multiple_β_chain)
510+ ys_pred_vec = vec (
511+ mean (Array (group (predictions_vec[:, :, chain_idx], :y )); dims= 1 )
512+ )
513+ @test ys_pred_vec[1 ] ≈ ground_truth_β * xs_test[1 ] atol = 0.01
514+ @test ys_pred_vec[2 ] ≈ ground_truth_β * xs_test[2 ] atol = 0.01
515+ end
516+ end
501517 end
502518
503519 @testset " with AbstractVector{<:AbstractVarInfo}" begin
@@ -524,7 +540,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
524540
525541 @test size (predicted_vis) == size (chain)
526542 @test Set (keys (predicted_vis[1 ])) ==
527- Set ([@varname (β), @varname (y[1 ]), @varname (y[2 ])])
543+ Set ([@varname (β), @varname (y[1 ]), @varname (y[2 ])])
528544 # because β samples are from the prior, the std will be larger
529545 @test mean ([
530546 predicted_vis[i][@varname (y[1 ])] for i in eachindex (predicted_vis)
0 commit comments