4 changes: 2 additions & 2 deletions src/compiler.jl
@@ -133,8 +133,8 @@ This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the
variables.

# Example
-```jldoctest; setup=:(using Distributions)
-julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal([1.0, 1.0], [1.0 0.0; 0.0 1.0]), randn(2, 2), @varname(x)); string(vns[end])
+```jldoctest; setup=:(using Distributions, LinearAlgebra)
+julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); string(vns[end])
"x[:,2]"

julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); string(vns[end])
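Reviewer note (not part of the diff): the updated doctest leans on `MvNormal(μ, I)` accepting a `UniformScaling` covariance. A minimal check, assuming Distributions ≥ 0.25 and `LinearAlgebra` are loaded, that the new form is the same distribution as the old explicit covariance matrix:

```julia
using Distributions, LinearAlgebra

# Both are a bivariate normal with mean (1, 1) and identity covariance.
d_old = MvNormal([1.0, 1.0], [1.0 0.0; 0.0 1.0])  # explicit covariance matrix
d_new = MvNormal(ones(2), I)                      # UniformScaling covariance
mean(d_old) == mean(d_new) && cov(d_old) == cov(d_new)  # true
```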
2 changes: 1 addition & 1 deletion src/loglikelihoods.jl
@@ -145,7 +145,7 @@ y .~ Normal(μ, σ)
```
3. using `MvNormal`:
```julia
-y ~ MvNormal(fill(μ, n), Diagonal(fill(σ, n)))
+y ~ MvNormal(fill(μ, n), σ^2 * I)
```

In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables,
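A quick check of the documented equivalence between forms (1)–(3), with illustrative values not taken from the diff: under the new API the second argument is a covariance, so `σ^2 * I` recovers `n` i.i.d. `Normal(μ, σ)` factors:

```julia
using Distributions, LinearAlgebra

μ, σ, n = 0.5, 2.0, 4  # illustrative values
y = randn(n)

lp_iid = sum(logpdf.(Normal(μ, σ), y))             # forms (1)/(2)
lp_mv  = logpdf(MvNormal(fill(μ, n), σ^2 * I), y)  # form (3)
lp_iid ≈ lp_mv  # true
```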
2 changes: 1 addition & 1 deletion src/varinfo.jl
@@ -34,7 +34,7 @@ Let `md` be an instance of `Metadata`:
To make `md::Metadata` type stable, all the `md.vns` must have the same symbol
and distribution type. However, one can have a Julia variable, say `x`, that is a
matrix or a hierarchical array sampled in partitions, e.g.
-`x[1][:] ~ MvNormal(zeros(2), 1.0); x[2][:] ~ MvNormal(ones(2), 1.0)`, and is managed by
+`x[1][:] ~ MvNormal(zeros(2), I); x[2][:] ~ MvNormal(ones(2), I)`, and is managed by
a single `md::Metadata` so long as all the distributions on the RHS of `~` are of the
same type. Type unstable `Metadata` will still work but will have inferior performance.
When sampling, the first iteration uses a type unstable `Metadata` for all the
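For illustration (a sketch, not code from this PR): the partitioned-array pattern the docstring describes, where both `~` statements share one RHS distribution type so a single type-stable `Metadata` can manage `x`:

```julia
using DynamicPPL, Distributions, LinearAlgebra

@model function partitioned(::Type{TV}=Vector{Float64}) where {TV}
    x = [TV(undef, 2) for _ in 1:2]
    # Same symbol `x`, same `MvNormal` type on both right-hand sides.
    x[1][:] ~ MvNormal(zeros(2), I)
    x[2][:] ~ MvNormal(ones(2), I)
    return x
end
```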
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
AbstractMCMC = "2.1, 3.0"
AbstractPPL = "0.2"
Bijectors = "0.9.5"
-Distributions = "< 0.25.11"
+Distributions = "0.25"
DistributionsAD = "0.6.3"
Documenter = "0.26.1, 0.27"
ForwardDiff = "0.10.12"
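For context (my reading of Pkg's compat rules, not stated in the diff): `"0.25"` is a caret specifier admitting any `0.25.x` release, so the bound no longer excludes 0.25.11, the release whose `MvNormal` deprecations this PR adapts to:

```julia
using Pkg.Types: semver_spec

v"0.25.11" in semver_spec("0.25")       # true: allowed by the new bound
v"0.25.11" in semver_spec("< 0.25.11")  # false: excluded by the old bound
```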
16 changes: 8 additions & 8 deletions test/compiler.jl
@@ -206,7 +206,7 @@ end
# test DPPL#61
@model function testmodel_missing5(z)
m ~ Normal()
-z[1:end] ~ MvNormal(fill(m, length(z)), 1.0)
+z[1:end] ~ MvNormal(fill(m, length(z)), I)
return m
end
model = testmodel_missing5(rand(10))
@@ -379,7 +379,7 @@ end

# AR1 model. Dynamic prefixing.
@model function AR1(num_steps, α, μ, σ, ::Type{TV}=Vector{Float64}) where {TV}
-η ~ MvNormal(num_steps, 1.0)
+η ~ MvNormal(zeros(num_steps), I)
δ = sqrt(1 - α^2)

x = TV(undef, num_steps)
@@ -400,7 +400,7 @@ end
num_obs = length(y)
@inbounds for i in 1:num_obs
x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ)
-y[i] ~ MvNormal(x, 0.1)
+y[i] ~ MvNormal(x, 0.01 * I)
end
end

@@ -419,7 +419,7 @@ end
x = Normal()
@test DynamicPPL.check_tilde_rhs(x) === x

-x = [Laplace(), Normal(), MvNormal(3, 1.0)]
+x = [Laplace(), Normal(), MvNormal(zeros(3), I)]
@test DynamicPPL.check_tilde_rhs(x) === x
end
@testset "isliteral" begin
@@ -436,14 +436,14 @@ end
# Verify that we indeed can parse this.
@test @model(function array_literal_model()
# `assume` and literal `observe`
-m ~ MvNormal(2, 1.0)
-return [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2))
+m ~ MvNormal(zeros(2), I)
+return [10.0, 10.0] ~ MvNormal(m, 0.25 * I)
end) isa Function

@model function array_literal_model2()
# `assume` and literal `observe`
-m ~ MvNormal(2, 1.0)
-return [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2))
+m ~ MvNormal(zeros(2), I)
+return [10.0, 10.0] ~ MvNormal(m, 0.25 * I)
end

@test array_literal_model2()() == [10.0, 10.0]
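The translation rule used throughout these tests, spelled out (my summary of the pre-/post-0.25.11 APIs): the old trailing argument was a standard deviation (scalar or vector), the new one is a covariance, so std `0.5` becomes `0.25 * I`:

```julia
using Distributions, LinearAlgebra

d = MvNormal(zeros(2), 0.25 * I)  # covariance 0.25 on the diagonal
sqrt.(var(d))  # [0.5, 0.5] — the stds the old `0.5 * ones(2)` expressed
```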
5 changes: 2 additions & 3 deletions test/context_implementations.jl
@@ -2,10 +2,9 @@
# https://github.com/TuringLang/DynamicPPL.jl/issues/129
@testset "#129" begin
@model function test(x)
-μ ~ MvNormal(fill(0, 2), 2.0)
+μ ~ MvNormal(zeros(2), 4 * I)
z = Vector{Int}(undef, length(x))
-# `z .~ Categorical.(ps)` cannot be parsed by Julia 1.0
-(.~)(z, Categorical.(fill([0.5, 0.5], length(x))))
+z .~ Categorical.(fill([0.5, 0.5], length(x)))
for i in 1:length(x)
x[i] ~ Normal(μ[z[i]], 0.1)
end
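An inference from this hunk (assuming Julia 1.0 support has been dropped): the broadcasted tilde can now be written directly. A minimal sketch of the parsed form:

```julia
using DynamicPPL, Distributions

@model function assignments(x)
    z = Vector{Int}(undef, length(x))
    # Parses on the Julia versions DynamicPPL still supports; the
    # `(.~)(z, ...)` spelling was only a workaround for Julia 1.0's parser.
    z .~ Categorical.(fill([0.5, 0.5], length(x)))
    return z
end
```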
30 changes: 15 additions & 15 deletions test/loglikelihoods.jl
@@ -1,28 +1,28 @@
# A collection of models for which the mean-of-means for the posterior should
# be the same.
-@model function gdemo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV}
+@model function gdemo1(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
# `dot_assume` and `observe`
m = TV(undef, length(x))
m .~ Normal()
-return x ~ MvNormal(m, 0.5)
+return x ~ MvNormal(m, 0.25 * I)
end

-@model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV}
+@model function gdemo2(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
# `assume` with indexing and `observe`
m = TV(undef, length(x))
for i in eachindex(m)
m[i] ~ Normal()
end
-return x ~ MvNormal(m, 0.5)
+return x ~ MvNormal(m, 0.25 * I)
end

-@model function gdemo3(x=10 * ones(2))
+@model function gdemo3(x=[10.0, 10.0])
# Multivariate `assume` and `observe`
-m ~ MvNormal(length(x), 1.0)
-return x ~ MvNormal(m, 0.5)
+m ~ MvNormal(zero(x), I)
+return x ~ MvNormal(m, 0.25 * I)
end

-@model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV}
+@model function gdemo4(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
# `dot_assume` and `observe` with indexing
m = TV(undef, length(x))
m .~ Normal()
@@ -33,16 +33,16 @@ end

# Using vector of `length` 1 here so the posterior of `m` is the same
# as the others.
-@model function gdemo5(x=10 * ones(1))
+@model function gdemo5(x=[10.0])
# `assume` and `dot_observe`
m ~ Normal()
return x .~ Normal(m, 0.5)
end

@model function gdemo6(::Type{TV}=Vector{Float64}) where {TV}
# `assume` and literal `observe`
-m ~ MvNormal(2, 1.0)
-return [10.0, 10.0] ~ MvNormal(m, 0.5)
+m ~ MvNormal(zeros(2), I)
+return [10.0, 10.0] ~ MvNormal(m, 0.25 * I)
end

@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV}
@@ -76,23 +76,23 @@ end
end

@model function _likelihood_dot_observe(m, x)
-return x ~ MvNormal(m, 0.5)
+return x ~ MvNormal(m, 0.25 * I)
end

-@model function gdemo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV}
+@model function gdemo10(x=[10.0, 10.0], ::Type{TV}=Vector{Float64}) where {TV}
m = TV(undef, length(x))
m .~ Normal()

# Submodel likelihood
@submodel _likelihood_dot_observe(m, x)
end

-@model function gdemo11(x=10 * ones(2, 1), ::Type{TV}=Vector{Float64}) where {TV}
+@model function gdemo11(x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64}) where {TV}
m = TV(undef, length(x))
m .~ Normal()

# Dotted observe for `Matrix`.
-return x .~ MvNormal(m, 0.5)
+return x .~ MvNormal(m, 0.25 * I)
end

const gdemo_models = (
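One detail of the `gdemo3` rewrite worth noting (a check under the new API, not part of the diff): `zero(x)` sizes the zero-mean prior from the data, so prior and observation dimensions cannot drift apart:

```julia
using Distributions, LinearAlgebra

x = [10.0, 10.0]
d = MvNormal(zero(x), I)  # zero-mean prior sized from the data
length(d) == length(x)    # true
```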
8 changes: 4 additions & 4 deletions test/prob_macro.jl
@@ -30,15 +30,15 @@
@testset "vector" begin
n = 5
@model function demo(x, n=n)
-m ~ MvNormal(n, 1.0)
-return x ~ MvNormal(m, 1.0)
+m ~ MvNormal(zeros(n), I)
+return x ~ MvNormal(m, I)
end
mval = rand(n)
xval = rand(n)
iters = 1000

-logprior = logpdf(MvNormal(n, 1.0), mval)
-loglike = logpdf(MvNormal(mval, 1.0), xval)
+logprior = logpdf(MvNormal(zeros(n), I), mval)
+loglike = logpdf(MvNormal(mval, I), xval)
logjoint = logprior + loglike

model = demo(xval)
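The test's arithmetic, restated (same quantities as the hunk above, with illustrative data): the joint log-density factors into a prior and a likelihood term:

```julia
using Distributions, LinearAlgebra

n = 5
mval, xval = rand(n), rand(n)
logprior = logpdf(MvNormal(zeros(n), I), mval)  # log p(m)
loglike  = logpdf(MvNormal(mval, I), xval)      # log p(x | m)
logjoint = logprior + loglike                   # log p(m, x)
```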
16 changes: 8 additions & 8 deletions test/turing/compiler.jl
@@ -175,12 +175,12 @@

res = sample(vdemo1b(x), alg, 250)

-D = 2
@model function vdemo2(x)
-μ ~ MvNormal(zeros(D), ones(D))
-@. x ~ $(MvNormal(μ, ones(D)))
+μ ~ MvNormal(zeros(size(x, 1)), I)
+@. x ~ $(MvNormal(μ, I))
end

+D = 2
alg = HMC(0.01, 5)
res = sample(vdemo2(randn(D, 100)), alg, 250)

@@ -206,7 +206,7 @@

t_vec = @elapsed res = sample(vdemo4(), alg, 1000)

-@model vdemo5() = x ~ MvNormal(zeros(N), 2 * ones(N))
+@model vdemo5() = x ~ MvNormal(zeros(N), 4 * I)

t_mv = @elapsed res = sample(vdemo5(), alg, 1000)

@@ -243,12 +243,12 @@
x = randn(100)
res = sample(vdemo1(x), alg, 250)

-D = 2
@model function vdemo2(x)
-μ ~ MvNormal(zeros(D), ones(D))
-return x .~ MvNormal(μ, ones(D))
+μ ~ MvNormal(zeros(size(x, 1)), I)
+return x .~ MvNormal(μ, I)
end

+D = 2
alg = HMC(0.01, 5)
res = sample(vdemo2(randn(D, 100)), alg, 250)

@@ -274,7 +274,7 @@

t_vec = @elapsed res = sample(vdemo4(), alg, 1000)

-@model vdemo5() = x ~ MvNormal(zeros(N), 2 * ones(N))
+@model vdemo5() = x ~ MvNormal(zeros(N), 4 * I)

t_mv = @elapsed res = sample(vdemo5(), alg, 1000)

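A note on the relocated `D = 2` (my reading of the change): the prior dimension is now derived from the data via `size(x, 1)`, so the model body no longer closes over a global. A sketch, assuming Turing's `@model`, `HMC`, and `sample` as in the surrounding tests:

```julia
using Turing, LinearAlgebra

@model function vdemo2(x)
    # Dimension comes from the data, not from a captured global `D`.
    μ ~ MvNormal(zeros(size(x, 1)), I)
    @. x ~ $(MvNormal(μ, I))  # one shared MvNormal observed per column
end

D = 2  # now only needed to build the test data
res = sample(vdemo2(randn(D, 100)), HMC(0.01, 5), 250)
```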
6 changes: 3 additions & 3 deletions test/turing/model.jl
@@ -14,7 +14,7 @@
end

@model function demo2(xs)
-m ~ MvNormal(2, 1.0)
+m ~ MvNormal(zeros(2), I)

for i in eachindex(xs)
xs[i] ~ Normal(m[1], 1.0)
@@ -50,7 +50,7 @@
@model function demo3(xs, ::Type{TV}=Vector{Float64}) where {TV}
m = Vector{TV}(undef, 2)
for i in 1:length(m)
-m[i] ~ MvNormal(2, 1.0)
+m[i] ~ MvNormal(zeros(2), I)
end

for i in eachindex(xs)
@@ -63,7 +63,7 @@
@model function demo4(xs, ::Type{TV}=Vector{Vector{Float64}}) where {TV}
m = TV(undef, 2)
for i in 1:length(m)
-m[i] ~ MvNormal(2, 1.0)
+m[i] ~ MvNormal(zeros(2), I)
end

for i in eachindex(xs)
8 changes: 4 additions & 4 deletions test/turing/prob_macro.jl
@@ -33,8 +33,8 @@
@testset "vector" begin
n = 5
@model function demo(x, n=n)
-m ~ MvNormal(n, 1.0)
-return x ~ MvNormal(m, 1.0)
+m ~ MvNormal(zeros(n), I)
+return x ~ MvNormal(m, I)
end
mval = rand(n)
xval = rand(n)
@@ -49,7 +49,7 @@

names = namesingroup(chain, "m")
lps = [
-logpdf(MvNormal(chain.value[i, names, j], 1.0), xval) for i in 1:size(chain, 1),
+logpdf(MvNormal(chain.value[i, names, j], I), xval) for i in 1:size(chain, 1),
j in 1:size(chain, 3)
]
@test logprob"x = xval | chain = chain" == lps
@@ -71,7 +71,7 @@
σ ~ truncated(Cauchy(0, 1), 0, Inf)
α ~ filldist(Normal(0, 10), n_groups)
μ = α[group]
-return y ~ MvNormal(μ, σ)
+return y ~ MvNormal(μ, σ^2 * I)
end

y = randn(100)
2 changes: 1 addition & 1 deletion test/turing/varinfo.jl
@@ -235,7 +235,7 @@
check_numerical(chain, ["p[1][1]"], [0]; atol=0.25)
end
@testset "varinfo" begin
-dists = [Normal(0, 1), MvNormal([0; 0], [1.0 0; 0 1.0]), Wishart(7, [1 0.5; 0.5 1])]
+dists = [Normal(0, 1), MvNormal(zeros(2), I), Wishart(7, [1 0.5; 0.5 1])]
function test_varinfo!(vi)
@test getlogp(vi) === 0.0
setlogp!(vi, 1)
6 changes: 3 additions & 3 deletions test/varinfo.jl
@@ -149,8 +149,8 @@
@model function testmodel(x)
n = length(x)
s ~ truncated(Normal(), 0, Inf)
-m ~ MvNormal(n, 1.0)
-return x ~ MvNormal(m, s)
+m ~ MvNormal(zeros(n), I)
+return x ~ MvNormal(m, s^2 * I)
end

@model function testmodel_univariate(x, ::Type{TV}=Vector{Float64}) where {TV}
@@ -259,7 +259,7 @@

# https://github.com/TuringLang/DynamicPPL.jl/issues/250
@model function demo()
-return x ~ filldist(MvNormal([1, 100], 1), 2)
+return x ~ filldist(MvNormal([1, 100], I), 2)
end

vi = VarInfo(demo())
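Finally, the `filldist` hunk (a check of the equivalence, assuming `filldist` comes from DistributionsAD as in the test suite): the old scalar `1` was a standard deviation, and `I` gives the same unit covariance per copy:

```julia
using Distributions, DistributionsAD, LinearAlgebra

d = filldist(MvNormal([1.0, 100.0], I), 2)  # two i.i.d. copies
size(rand(d))  # (2, 2): one column per copy
```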