diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl index e447bd2fc8..a75779e73e 100644 --- a/src/mcmc/gibbs.jl +++ b/src/mcmc/gibbs.jl @@ -438,7 +438,9 @@ function setparams_varinfo!!( state::TuringState, params::AbstractVarInfo, ) - logdensity = DynamicPPL.setmodel(state.ldf, model, sampler.alg.adtype) + logdensity = DynamicPPL.LogDensityFunction( + model, state.ldf.varinfo, state.ldf.context; adtype=sampler.alg.adtype + ) new_inner_state = setparams_varinfo!!( AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params ) diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 9a6124e14a..0c322244eb 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -72,7 +72,7 @@ function DynamicPPL.initialstep( DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext()); adtype=spl.alg.adtype, ) - state = SGHMCState(ℓ, vi, zero(vi[spl])) + state = SGHMCState(ℓ, vi, zero(vi[:])) return sample, state end @@ -87,7 +87,7 @@ function AbstractMCMC.step( # Compute gradient of log density. ℓ = state.logdensity vi = state.vi - θ = vi[spl] + θ = vi[:] grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ)) # Update latent variables and velocity according to @@ -246,7 +246,7 @@ function AbstractMCMC.step( # Perform gradient step. ℓ = state.logdensity vi = state.vi - θ = vi[spl] + θ = vi[:] grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ)) step = state.step stepsize = spl.alg.stepsize(step) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 36ff7cc81e..fac3a2f402 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -512,7 +512,7 @@ using Turing @model function vdemo2(x) μ ~ MvNormal(zeros(size(x, 1)), I) - return x .~ MvNormal(μ, I) + return x ~ filldist(MvNormal(μ, I), size(x, 2)) end D = 2 @@ -560,7 +560,7 @@ using Turing @model function vdemo7() x = Array{Real}(undef, N, N) - return x .~ [InverseGamma(2, 3) for i in 1:N] + return x ~ filldist(InverseGamma(2, 3), N, N) end sample(StableRNG(seed), vdemo7(), alg, 10) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index d802dc1db8..718b3cfe3f 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -218,7 +218,7 @@ using Turing # https://github.com/TuringLang/Turing.jl/issues/1308 @model function mwe3(::Type{T}=Array{Float64}) where {T} m = T(undef, 2, 3) - return m .~ MvNormal(zeros(2), I) + return m ~ filldist(MvNormal(zeros(2), I), 3) end @test sample(StableRNG(seed), mwe3(), HMC(0.2, 4; adtype=adbackend), 100) isa Chains end diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index 2d133d3693..d190e589a5 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -238,7 +238,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var)) # Link if proposal is `AdvancedHM.RandomWalkProposal` vi = deepcopy(vi_base) - d = length(vi_base[DynamicPPL.SampleFromPrior()]) + d = length(vi_base[:]) alg = MH(AdvancedMH.RandomWalkProposal(MvNormal(zeros(d), I))) spl = DynamicPPL.Sampler(alg) vi = Turing.Inference.maybe_link!!(vi, spl, alg.proposals, gdemo_default)