From e0fec7cab286cf575cac6cf0a91504d813654b25 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Jul 2021 23:05:45 +0100 Subject: [PATCH 1/7] use views whenever possible --- src/compiler.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..c627f6806 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -27,7 +27,7 @@ function isassumption(expr::Union{Symbol,Expr}) true else # Evaluate the LHS - $expr === missing + $(maybe_view(expr)) === missing end end end @@ -36,6 +36,13 @@ end # failsafe: a literal is never an assumption isassumption(expr) = :(false) +# If we're working with, say, a `Symbol`, then we're not going to `view`. +maybe_view(x) = x +maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@view($x))) + +maybe_unwrap_view(x) = x +maybe_unwrap_view(x::SubArray{<:Any, 0}) = x[1] + """ isliteral(expr) @@ -300,7 +307,7 @@ function generate_tilde(left, right) if isliteral(left) return quote $(DynamicPPL.tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), __varinfo__ ) end end @@ -325,7 +332,7 @@ function generate_tilde(left, right) $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), - $left, + $(maybe_view(left)), $vn, $inds, __varinfo__, @@ -344,7 +351,7 @@ function generate_dot_tilde(left, right) if isliteral(left) return quote $(DynamicPPL.dot_tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), __varinfo__ ) end end @@ -360,7 +367,7 @@ function generate_dot_tilde(left, right) $left .= $(DynamicPPL.dot_tilde_assume!)( __context__, $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $left, $vn + $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn )..., $inds, __varinfo__, @@ -369,7 +376,7 @@ function generate_dot_tilde(left, right) $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), - $left, + $(maybe_view(left)), $vn, $inds, __varinfo__, From 4f76ceb4c138aa7c9cabfcb670edc6fb2dfa9cad Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Jul 2021 23:19:58 +0100 Subject: [PATCH 2/7] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/compiler.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index c627f6806..4802ace5f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -41,7 +41,7 @@ maybe_view(x) = x maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@view($x))) maybe_unwrap_view(x) = x -maybe_unwrap_view(x::SubArray{<:Any, 0}) = x[1] +maybe_unwrap_view(x::SubArray{<:Any,0}) = x[1] """ isliteral(expr) @@ -307,7 +307,10 @@ function generate_tilde(left, right) if isliteral(left) return quote $(DynamicPPL.tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), __varinfo__ + __context__, + $(DynamicPPL.check_tilde_rhs)($right), + $(maybe_view(left)), + __varinfo__, ) end end @@ -351,7 +354,10 @@ function generate_dot_tilde(left, right) if isliteral(left) return quote $(DynamicPPL.dot_tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), __varinfo__ + __context__, + $(DynamicPPL.check_tilde_rhs)($right), + $(maybe_view(left)), + __varinfo__, ) end end From 5a065ec2949b7dbdca7219e8507a05a6e56ca5ae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Jul 2021 23:24:45 +0100 Subject: [PATCH 3/7] dont view literals --- src/compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index c627f6806..a08afe713 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -307,7 +307,7 @@ function generate_tilde(left, right) if isliteral(left) return quote $(DynamicPPL.tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), __varinfo__ + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -351,7 +351,7 @@ function generate_dot_tilde(left, right) if isliteral(left) return quote $(DynamicPPL.dot_tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), __varinfo__ + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end From 2713a77b85f45d73adf5d462c83749513502ade0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 13 Jul 2021 23:37:02 +0100 Subject: [PATCH 4/7] fixed the failing tests --- test/turing/loglikelihoods.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/turing/loglikelihoods.jl b/test/turing/loglikelihoods.jl index 2fb991089..81bcf15a1 100644 --- a/test/turing/loglikelihoods.jl +++ b/test/turing/loglikelihoods.jl @@ -32,8 +32,8 @@ results = pointwise_loglikelihoods(model, var_info) var_to_likelihoods = Dict(string(vn) => ℓ for (vn, ℓ) in results) s, m = var_info[SampleFromPrior()] - @test logpdf(Normal(m, √s), xs[1]) == var_to_likelihoods["xs[1]"] - @test logpdf(Normal(m, √s), xs[2]) == var_to_likelihoods["xs[2]"] - @test logpdf(Normal(m, √s), xs[3]) == var_to_likelihoods["xs[3]"] - @test logpdf(Normal(m, √s), y) == var_to_likelihoods["y"] + @test [logpdf(Normal(m, √s), xs[1])] == var_to_likelihoods["xs[1]"] + @test [logpdf(Normal(m, √s), xs[2])] == var_to_likelihoods["xs[2]"] + @test [logpdf(Normal(m, √s), xs[3])] == var_to_likelihoods["xs[3]"] + @test [logpdf(Normal(m, √s), y)] == var_to_likelihoods["y"] end From 74acfda6c98e45af067965318f180ff311f36920 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 13 Jul 2021 23:37:43 +0100 Subject: [PATCH 5/7] added a bunch of get_sections to tests to avoid unnecessary warnings --- test/turing/loglikelihoods.jl | 4 +++- test/turing/model.jl | 20 ++++++++++---------- test/turing/prob_macro.jl | 18 ++++++++++++++---- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/test/turing/loglikelihoods.jl b/test/turing/loglikelihoods.jl index 81bcf15a1..5f1c41572 100644 --- a/test/turing/loglikelihoods.jl +++ b/test/turing/loglikelihoods.jl @@ -13,7 +13,9 @@ y = randn() model = demo(xs, y) chain = sample(model, MH(), MCMCThreads(), 100, 2) - var_to_likelihoods = pointwise_loglikelihoods(model, chain) + var_to_likelihoods = pointwise_loglikelihoods( + model, MCMCChains.get_sections(chain, :parameters) + ) @test haskey(var_to_likelihoods, "xs[1]") @test haskey(var_to_likelihoods, "xs[2]") @test haskey(var_to_likelihoods, "xs[3]") diff --git a/test/turing/model.jl b/test/turing/model.jl index c41b2a5be..5dbea9839 100644 --- a/test/turing/model.jl +++ b/test/turing/model.jl @@ -30,11 +30,11 @@ chain1 = sample(model1, MH(), 100) chain2 = sample(model2, MH(), 100) - res11 = generated_quantities(model1, chain1) - res21 = generated_quantities(model2, chain1) + res11 = generated_quantities(model1, MCMCChains.get_sections(chain1, :parameters)) + res21 = generated_quantities(model2, MCMCChains.get_sections(chain1, :parameters)) - res12 = generated_quantities(model1, chain2) - res22 = generated_quantities(model2, chain2) + res12 = generated_quantities(model1, MCMCChains.get_sections(chain2, :parameters)) + res22 = generated_quantities(model2, MCMCChains.get_sections(chain2, :parameters)) # Check that the two different models produce the same values for # the same chains. @@ -43,8 +43,8 @@ # Ensure that they're not all the same (some can be, because rejected samples) @test any(res12[1:(end - 1)] .!= res12[2:end]) - test_setval!(model1, chain1) - test_setval!(model2, chain2) + test_setval!(model1, MCMCChains.get_sections(chain1, :parameters)) + test_setval!(model2, MCMCChains.get_sections(chain2, :parameters)) # Next level @model function demo3(xs, ::Type{TV}=Vector{Float64}) where {TV} @@ -79,11 +79,11 @@ chain3 = sample(model3, MH(), 100) chain4 = sample(model4, MH(), 100) - res33 = generated_quantities(model3, chain3) - res43 = generated_quantities(model4, chain3) + res33 = generated_quantities(model3, MCMCChains.get_sections(chain3, :parameters)) + res43 = generated_quantities(model4, MCMCChains.get_sections(chain3, :parameters)) - res34 = generated_quantities(model3, chain4) - res44 = generated_quantities(model4, chain4) + res34 = generated_quantities(model3, MCMCChains.get_sections(chain4, :parameters)) + res44 = generated_quantities(model4, MCMCChains.get_sections(chain4, :parameters)) # Check that the two different models produce the same values for # the same chains. diff --git a/test/turing/prob_macro.jl b/test/turing/prob_macro.jl index 0eb2a1290..5f0580b55 100644 --- a/test/turing/prob_macro.jl +++ b/test/turing/prob_macro.jl @@ -11,7 +11,9 @@ model = demo(xval) varinfo = VarInfo(model) - chain = sample(model, IS(), iters; save_state=true) + chain = MCMCChains.get_sections( + sample(model, IS(), iters; save_state=true), :parameters + ) chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple()) lps = logpdf.(Normal.(chain["m"], 1), xval) @test logprob"x = xval | chain = chain" == lps @@ -40,7 +42,9 @@ model = demo(xval) varinfo = VarInfo(model) - chain = sample(model, HMC(0.5, 1), iters; save_state=true) + chain = MCMCChains.get_sections( + sample(model, HMC(0.5, 1), iters; save_state=true), :parameters + ) chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple()) names = namesingroup(chain, "m") @@ -74,7 +78,10 @@ group = rand(1:4, 100) n_groups = 4 - chain1 = sample(model1(y, group, n_groups), NUTS(0.65), 2_000; save_state=true) + chain1 = MCMCChains.get_sections( + sample(model1(y, group, n_groups), NUTS(0.65), 2_000; save_state=true), + :parameters, + ) logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain1" @model function model2(y, group, n_groups) @@ -85,7 +92,10 @@ end end - chain2 = sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true) + chain2 = MCMCChains.get_sections( + sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true), + :parameters, + ) logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain2" end end From 80ff05344fe1bd09c56679631bb59b82b8f912fc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 13 Jul 2021 23:39:15 +0100 Subject: [PATCH 6/7] formatting --- test/compat/ad.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/compat/ad.jl b/test/compat/ad.jl index 1d8b02c55..3a8058ca9 100644 --- a/test/compat/ad.jl +++ b/test/compat/ad.jl @@ -8,7 +8,8 @@ return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) + - logpdf(dist, 1.5) + logpdf(dist, 2.0) + logpdf(dist, 1.5) + + logpdf(dist, 2.0) end test_model_ad(gdemo_default, logp_gdemo_default) From b18918d4e6272a2015b30eaaf961c97eb5d2a93a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 13 Jul 2021 23:47:45 +0100 Subject: [PATCH 7/7] added comment to describe maybe_unwrap_view --- src/compiler.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 8ad31ed17..91fe78e2b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -40,6 +40,9 @@ isassumption(expr) = :(false) maybe_view(x) = x maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@view($x))) +# If the result of a `view` is a zero-dim array then it's just a +# single element. Likely the rest is expecting type `eltype(x)`, hence +# we extract the value rather than passing the array. maybe_unwrap_view(x) = x maybe_unwrap_view(x::SubArray{<:Any,0}) = x[1]