Description
Lately I have been experimenting with various ways to speed up a Turing model. I tried the `@distributed` approach, but it produced completely wrong results.
My first question: how can I get correct results with `@distributed`? My second question: is there another way to get results both quickly and correctly? I have tried `Threads.@threads`, but the speed-up is not good enough.
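(For context, a standard way to get a speed-up without touching the model body is to parallelize across *chains* rather than inside the likelihood loop: Turing's `sample` accepts `MCMCThreads()` for thread-parallel chains. A minimal sketch with a stand-in model, not the one from this issue:)

```julia
using Distributions, Turing, Random

# Toy stand-in model: one Poisson rate, scalar observations.
@model function demo(x)
    λ ~ Gamma(2, 2)
    for i in eachindex(x)
        x[i] ~ Poisson(λ)
    end
end

Random.seed!(1)
data = rand(Poisson(5.0), 100)

# Start Julia with `julia -t 4`, then draw 4 chains, one per thread.
# Each chain is a full, correct serial run; only the chains run in parallel.
chn = sample(demo(data), NUTS(), MCMCThreads(), 1000, 4)
```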
Running the code below, the posterior means match the means of the Poisson distributions used to generate the data.
```julia
using Distributions, Turing, Random
using Distributed

n = 3000
μ = rand(Uniform(20, 40), 10)
# 10-element Vector{Float64}:
#  30.256065112894298
#  26.014315408019847
#  24.61075823480178
#  21.519220815830373
#  24.423829409052153
#  31.93931799311691
#  36.88338733245446
#  23.069268268746683
#  27.68725276509452
#  22.156260959934805

rd = rand.(Poisson.(μ), n)   # 10 vectors of 3000 Poisson draws each
```
```julia
@model function Turing_tele(y)
    # Prior on the ten Poisson rates.
    ρ ~ filldist(Gamma(6.5, 2.25), 10)
    for i = 1:10
        y[i] ~ Poisson(ρ[i])
    end
end
```
```julia
model = Turing_tele(rd)
Random.seed!(1)
@time chn = sample(model, NUTS(), 1000)
```

```
┌ Info: Found initial step size
└   ϵ = 0.00625
 16.337205 seconds (23.63 M allocations: 1.696 GiB, 3.03% gc time, 38.45% compilation time)
```
```
Chains MCMC chain (1000×22×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 13.49 seconds
Compute duration  = 13.49 seconds
parameters        = ρ[1], ρ[2], ρ[3], ρ[4], ρ[5], ρ[6], ρ[7], ρ[8], ρ[9], ρ[10]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean     std    mcse   ess_bulk  ess_tail    rhat  ess_per_sec
      Symbol   Float64 Float64 Float64    Float64   Float64 Float64      Float64

        ρ[1]   30.4083  0.1023  0.0026  1601.3122  928.1514  1.0008     118.6861
        ρ[2]   25.9982  0.0919  0.0027  1230.7776  867.4562  1.0015      91.2228
        ρ[3]   24.6462  0.0927  0.0028  1137.7679  706.6938  1.0031      84.3291
        ρ[4]   21.5981  0.0860  0.0029   879.0077  765.2233  1.0019      65.1503
        ρ[5]   24.3673  0.0912  0.0028  1067.3277  843.5914  1.0006      79.1082
        ρ[6]   32.0114  0.0992  0.0029  1153.5239  818.6120  1.0014      85.4969
        ρ[7]   36.7634  0.1093  0.0027  1677.7044  812.3265  1.0005     124.3481
        ρ[8]   23.1246  0.0927  0.0024  1452.7873  684.0625  0.9994     107.6777
        ρ[9]   27.6116  0.0966  0.0027  1301.1715  852.6950  0.9996      96.4402
       ρ[10]   22.2285  0.0864  0.0025  1239.3571  842.7230  1.0018      91.8587

Quantiles
  parameters      2.5%    25.0%    50.0%    75.0%    97.5%
      Symbol   Float64  Float64  Float64  Float64  Float64

        ρ[1]   30.2058  30.3386  30.4098  30.4800  30.6128
        ρ[2]   25.8195  25.9341  26.0001  26.0641  26.1756
        ρ[3]   24.4712  24.5808  24.6493  24.7103  24.8221
        ρ[4]   21.4327  21.5428  21.5973  21.6572  21.7586
        ρ[5]   24.1889  24.3095  24.3667  24.4213  24.5552
        ρ[6]   31.8253  31.9456  32.0105  32.0761  32.2037
        ρ[7]   36.5529  36.6878  36.7621  36.8360  36.9851
        ρ[8]   22.9364  23.0670  23.1255  23.1863  23.2879
        ρ[9]   27.4253  27.5458  27.6089  27.6797  27.7924
       ρ[10]   22.0631  22.1699  22.2254  22.2879  22.4102
```
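(Independent of parallelism, this particular model can be made much faster by exploiting Poisson sufficiency: if y_i1, …, y_in are iid Poisson(ρ_i), their sum is Poisson(n·ρ_i) and carries all the information about ρ_i, so the model only needs ten observed values instead of 30,000. The posterior is the same because the per-observation likelihood differs from the sufficient-statistic likelihood only by a constant in ρ. `Turing_tele_suff` is my name for this variant, not part of the original issue:)

```julia
using Distributions, Turing, Random

n = 3000
μ = rand(Uniform(20, 40), 10)
rd = rand.(Poisson.(μ), n)

@model function Turing_tele_suff(s, n)
    ρ ~ filldist(Gamma(6.5, 2.25), 10)
    for i = 1:10
        # Sum of n iid Poisson(ρ[i]) draws is Poisson(n * ρ[i]).
        s[i] ~ Poisson(n * ρ[i])
    end
end

s = sum.(rd)                      # one sufficient statistic per group
Random.seed!(1)
chn = sample(Turing_tele_suff(s, n), NUTS(), 1000)
```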
But if I put `@distributed` before the loop in the Turing model, the result is completely wrong: every posterior mean comes out around 14.6 with a standard deviation around 5.7, which is just the Gamma(6.5, 2.25) prior (mean 6.5 × 2.25 = 14.625, sd √6.5 × 2.25 ≈ 5.74), as if the data were being ignored.
```julia
@model function Turing_tele(y)
    # Prior on the ten Poisson rates.
    ρ ~ filldist(Gamma(6.5, 2.25), 10)
    @distributed for i = 1:10
        y[i] ~ Poisson(ρ[i])
    end
end
```
```
┌ Info: Found initial step size
└   ϵ = 1.6
 11.582992 seconds (4.06 M allocations: 425.341 MiB, 0.37% gc time, 9.47% compilation time: 73% of which was recompilation)

Chains MCMC chain (1000×22×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 11.34 seconds
Compute duration  = 11.34 seconds
parameters        = ρ[1], ρ[2], ρ[3], ρ[4], ρ[5], ρ[6], ρ[7], ρ[8], ρ[9], ρ[10]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean     std    mcse   ess_bulk  ess_tail    rhat  ess_per_sec
      Symbol   Float64 Float64 Float64    Float64   Float64 Float64      Float64

        ρ[1]   14.2881  5.5633  0.1703  1022.7573  578.0672  0.9999      90.1664
        ρ[2]   14.7304  5.7350  0.1581  1220.2044  765.2927  0.9991     107.5733
        ρ[3]   14.8755  5.7867  0.1630  1227.5404  740.2345  1.0013     108.2201
        ρ[4]   14.5879  5.7901  0.1513  1388.8694  464.7315  0.9991     122.4429
        ρ[5]   14.3013  6.2120  0.1503  1514.4644  655.4396  1.0006     133.5153
        ρ[6]   14.6386  5.7405  0.1530  1358.4567  614.7594  0.9992     119.7617
        ρ[7]   14.6543  5.8039  0.1597  1229.2604  641.8933  1.0000     108.3717
        ρ[8]   14.3672  5.5997  0.1624  1124.2711  640.0229  1.0024      99.1159
        ρ[9]   14.5488  5.5751  0.1531  1155.3912  524.0374  1.0014     101.8594
       ρ[10]   14.5977  5.9391  0.1516  1441.3284  636.2089  1.0091     127.0677
```
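(A likely explanation, stated as my reading rather than a confirmed diagnosis: `@distributed for` without a reducer dispatches the loop body to worker processes and returns immediately, so the `~` statements execute against copies of the model's internal state on the workers and the accumulated log-likelihood never reaches the sampler on the main process. The supported way to combine Turing with the Distributed stdlib is to keep the model serial and parallelize across chains with `MCMCDistributed()`. A sketch using the same model:)

```julia
using Distributed
addprocs(4)                        # four worker processes
@everywhere using Distributions, Turing

# The model must be defined on every worker; note the plain serial loop.
@everywhere @model function Turing_tele(y)
    ρ ~ filldist(Gamma(6.5, 2.25), 10)
    for i = 1:10
        y[i] ~ Poisson(ρ[i])       # no @distributed here
    end
end

n = 3000
μ = rand(Uniform(20, 40), 10)
rd = rand.(Poisson.(μ), n)

# Four chains, one per worker; each chain is a correct serial run.
chn = sample(Turing_tele(rd), NUTS(), MCMCDistributed(), 1000, 4)
```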