Skip to content

Commit 22e5dd1

Browse files
committed
More test fixes
1 parent 7c32e3e commit 22e5dd1

File tree

5 files changed

+8
-8
lines changed

5 files changed

+8
-8
lines changed

src/mcmc/gibbs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ function setparams_varinfo!!(
438438
state::TuringState,
439439
params::AbstractVarInfo,
440440
)
441-
logdensity = DynamicPPL.setmodel(state.ldf, model, sampler.alg.adtype)
441+
logdensity = DynamicPPL.LogDensityFunction(model; adtype=sampler.alg.adtype)
442442
new_inner_state = setparams_varinfo!!(
443443
AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params
444444
)

src/mcmc/sghmc.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function DynamicPPL.initialstep(
7272
DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext());
7373
adtype=spl.alg.adtype,
7474
)
75-
state = SGHMCState(ℓ, vi, zero(vi[spl]))
75+
state = SGHMCState(ℓ, vi, zero(vi[:]))
7676

7777
return sample, state
7878
end
@@ -87,7 +87,7 @@ function AbstractMCMC.step(
8787
# Compute gradient of log density.
8888
= state.logdensity
8989
vi = state.vi
90-
θ = vi[spl]
90+
θ = vi[:]
9191
grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ))
9292

9393
# Update latent variables and velocity according to
@@ -246,7 +246,7 @@ function AbstractMCMC.step(
246246
# Perform gradient step.
247247
= state.logdensity
248248
vi = state.vi
249-
θ = vi[spl]
249+
θ = vi[:]
250250
grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ))
251251
step = state.step
252252
stepsize = spl.alg.stepsize(step)

test/mcmc/Inference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ using Turing
512512

513513
@model function vdemo2(x)
514514
μ ~ MvNormal(zeros(size(x, 1)), I)
515-
return x .~ MvNormal(μ, I)
515+
return x ~ filldist(MvNormal(μ, I), size(x, 2))
516516
end
517517

518518
D = 2
@@ -560,7 +560,7 @@ using Turing
560560

561561
@model function vdemo7()
562562
x = Array{Real}(undef, N, N)
563-
return x .~ [InverseGamma(2, 3) for i in 1:N]
563+
return x ~ filldist(InverseGamma(2, 3), N, N)
564564
end
565565

566566
sample(StableRNG(seed), vdemo7(), alg, 10)

test/mcmc/hmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ using Turing
218218
# https://github.com/TuringLang/Turing.jl/issues/1308
219219
@model function mwe3(::Type{T}=Array{Float64}) where {T}
220220
m = T(undef, 2, 3)
221-
return m .~ MvNormal(zeros(2), I)
221+
return m ~ filldist(MvNormal(zeros(2), I), 3)
222222
end
223223
@test sample(StableRNG(seed), mwe3(), HMC(0.2, 4; adtype=adbackend), 100) isa Chains
224224
end

test/mcmc/mh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var))
238238

239239
# Link if proposal is `AdvancedHM.RandomWalkProposal`
240240
vi = deepcopy(vi_base)
241-
d = length(vi_base[DynamicPPL.SampleFromPrior()])
241+
d = length(vi_base[:])
242242
alg = MH(AdvancedMH.RandomWalkProposal(MvNormal(zeros(d), I)))
243243
spl = DynamicPPL.Sampler(alg)
244244
vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default)

0 commit comments

Comments
 (0)