From c13e85ad0e05f73f31108238f622f4ed3e636f88 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 21 Nov 2020 22:10:47 +1100 Subject: [PATCH 1/5] allow redefinition of inputs in logprob --- src/context_implementations.jl | 4 ++-- src/prob_macro.jl | 2 +- test/prob_macro.jl | 17 +++++++++++++++++ 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index cf7e7f7f3..04f9c4974 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -30,7 +30,7 @@ function tilde(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) return _tilde(rng, sampler, right, vn, vi) end function tilde(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars !== nothing + if ctx.vars !== nothing && ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end @@ -169,7 +169,7 @@ function dot_tilde( inds, vi, ) - if ctx.vars !== nothing + if ctx.vars !== nothing && ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) var = _getindex(getfield(ctx.vars, getsym(vn)), inds) vns, dist = get_vns_and_dist(right, var, vn) set_val!(vi, vns, dist, var) diff --git a/src/prob_macro.jl b/src/prob_macro.jl index 6d7b5ebee..52d1c7466 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -190,7 +190,7 @@ function Distributions.loglikelihood( if isdefined(right, :chain) # Element-wise likelihood for each value in chain chain = right.chain - ctx = LikelihoodContext() + ctx = LikelihoodContext(right) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) logps = map(iters) do (sample_idx, chain_idx) setval!(vi, chain, sample_idx, chain_idx) diff --git a/test/prob_macro.jl b/test/prob_macro.jl index 5bebbffdf..01359029b 100644 --- a/test/prob_macro.jl +++ b/test/prob_macro.jl @@ -128,4 +128,21 @@ Random.seed!(129) chain2 = sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true) logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain2" end + + @testset "issue190" begin + @model function gdemo(x, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x ~ filldist(Normal(m, sqrt(s)), length(y)) + for i in 1:length(y) + y[i] ~ Normal(x[i], sqrt(s)) + end + end + model_gdemo = gdemo([1.0, 0.0], [1.5, 0.0]) + c2 = sample(model_gdemo, NUTS(0.65), 100) + result1 = prob"y = [1.5] | chain=c2, model = model_gdemo, x = [1.0]" + result2 = map(c2[:s]) do s + exp(logpdf(Normal(1.0, sqrt(s)), 1.5)) + end + end end From bca62c4318495ae0054f5d46f3565335c407fcba Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 21 Nov 2020 22:18:13 +1100 Subject: [PATCH 2/5] remove unnecessary check --- src/context_implementations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 04f9c4974..6b3542acd 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -30,7 +30,7 @@ function tilde(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) return _tilde(rng, sampler, right, vn, vi) end function tilde(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars !== nothing && ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) + if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end @@ -169,7 +169,7 @@ function dot_tilde( inds, vi, ) - if ctx.vars !== nothing && ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) + if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) var = _getindex(getfield(ctx.vars, getsym(vn)), inds) vns, dist = get_vns_and_dist(right, var, vn) set_val!(vi, vns, dist, var) From aec4d5bde78ff0cc7834a33e01d862816778ccf6 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sat, 21 Nov 2020 22:45:27 +1100 Subject: [PATCH 3/5] Update test/prob_macro.jl Co-authored-by: David Widmann --- test/prob_macro.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prob_macro.jl b/test/prob_macro.jl index 01359029b..6f6f9da58 100644 --- a/test/prob_macro.jl +++ b/test/prob_macro.jl @@ -142,7 +142,7 @@ Random.seed!(129) c2 = sample(model_gdemo, NUTS(0.65), 100) result1 = prob"y = [1.5] | chain=c2, model = model_gdemo, x = [1.0]" result2 = map(c2[:s]) do s - exp(logpdf(Normal(1.0, sqrt(s)), 1.5)) + pdf(Normal(1, sqrt(s)), 1.5) end end end From 4de60603f95a8e47ac3b7b08f00affc0d5a2678b Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 21 Nov 2020 23:01:14 +1100 Subject: [PATCH 4/5] avoid Turing and revert pdf(..) change --- test/prob_macro.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/prob_macro.jl b/test/prob_macro.jl index 6f6f9da58..28a8ee6cc 100644 --- a/test/prob_macro.jl +++ b/test/prob_macro.jl @@ -138,11 +138,13 @@ Random.seed!(129) y[i] ~ Normal(x[i], sqrt(s)) end end + c = Chains(rand(10, 2), [:m, :s]) model_gdemo = gdemo([1.0, 0.0], [1.5, 0.0]) - c2 = sample(model_gdemo, NUTS(0.65), 100) - result1 = prob"y = [1.5] | chain=c2, model = model_gdemo, x = [1.0]" - result2 = map(c2[:s]) do s - pdf(Normal(1, sqrt(s)), 1.5) + r1 = prob"y = [1.5] | chain=c, model = model_gdemo, x = [1.0]" + r2 = map(c[:s]) do s + # exp(logpdf(..)) not pdf because this is exactly what the prob"" macro does, so we test r1 == r2 + exp(logpdf(Normal(1, sqrt(s)), 1.5)) end + @test r1 == r2 end end From 84e8ee4fd3dc216d73452a55247c6cdbbdad6156 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 22 Nov 2020 21:57:56 +1100 Subject: [PATCH 5/5] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e38ab9eee..1a07c0b4a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.9.7" +version = "0.9.8" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"