Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/samplers/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ function sample(model::Function, alg::Gibbs)

if isa(local_spl.alg, Hamiltonian)
lp = realpart(getlogp(varInfo))
epsilon = local_spl.info[][end]
epsilon = local_spl.info[:wum][:ϵ][end]
lf_num = local_spl.info[:lf_num]
end
else
Expand Down
8 changes: 5 additions & 3 deletions src/samplers/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ end
# it now reuses the one of HMCDA
Sampler(alg::HMC) = begin
spl = Sampler(HMCDA(alg.n_iters, 0, 0.0, alg.epsilon * alg.tau, alg.space, alg.gid))
spl.info[:ϵ] = [alg.epsilon]
spl.info[:pre_set_ϵ] = alg.epsilon
spl
end

Expand All @@ -58,7 +58,7 @@ Sampler(alg::Hamiltonian) = begin
info[:θ_mean] = nothing
info[:θ_num] = 0
info[:stds] = nothing
info[:θ_vars] = nothing
info[:vars] = nothing

# For caching gradient
info[:grad_cache] = Dict{Vector,Vector}()
Expand Down Expand Up @@ -118,7 +118,9 @@ function sample{T<:Hamiltonian}(model::Function, alg::T, chunk_size::Int)
end
println(" #lf / sample = $(spl.info[:total_lf_num] / n);")
println(" #evals / sample = $(spl.info[:total_eval_num] / n);")
println(" pre-cond. diag mat = $(spl.info[:stds]).")
stds_str = string(spl.info[:wum][:stds])
stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str
println(" pre-cond. diag mat = $(stds_str).")

global CHUNKSIZE = default_chunk_size

Expand Down
19 changes: 9 additions & 10 deletions src/samplers/hmcda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,23 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
if is_first
if spl.alg.gid != 0 link!(vi, spl) end # X -> R

init_pre_cond_parameters(vi, spl)
init_warm_up_params(vi, spl)

ϵ = spl.alg.delta > 0 ?
find_good_eps(model, vi, spl) : # heuristically find optimal ϵ
spl.info[:ϵ][end]
spl.info[:pre_set_ϵ]

if spl.alg.gid != 0 invlink!(vi, spl) end # R -> X

init_da_parameters(spl, ϵ)
update_da_params(spl.info[:wum], ϵ)

push!(spl.info[:accept_his], true)

vi
else
# Set parameters
λ = spl.alg.lambda
ϵ = spl.info[:ϵ][end]; dprintln(2, "current ϵ: $ϵ")
ϵ = spl.info[:wum][:ϵ][end]; dprintln(2, "current ϵ: $ϵ")

spl.info[:lf_num] = 0 # reset current lf num counter

Expand Down Expand Up @@ -95,15 +95,14 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
α = min(1, exp(-(H - old_H)))

if ~(isdefined(Main, :IJulia) && Main.IJulia.inited) # Fix for Jupyter notebook.
stds_str = string(spl.info[:wum][:stds])
stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str
haskey(spl.info, :progress) && ProgressMeter.update!(
spl.info[:progress],
spl.info[:progress].counter; showvalues = [(:ϵ, ϵ), (:α, α), (:pre_cond, spl.info[:stds])]
spl.info[:progress].counter; showvalues = [(:ϵ, ϵ), (:α, α), (:pre_cond, stds_str)]
)
end

# Use Dual Averaging to adapt ϵ
adapt_step_size(spl, α, spl.alg.delta)

dprintln(2, "decide wether to accept...")
if rand() < α # accepted
push!(spl.info[:accept_his], true)
Expand All @@ -113,8 +112,8 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
setlogp!(vi, old_logp) # reset logp
end

# Update pre-conditioning matrix
update_pre_cond(vi, spl)
# Adapt step-size and pre-cond
adapt(spl.info[:wum], α, realpart(vi[spl]))

dprintln(3, "R -> X...")
if spl.alg.gid != 0 invlink!(vi, spl); cleandual!(vi) end
Expand Down
22 changes: 11 additions & 11 deletions src/samplers/nuts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,20 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)
if is_first
if spl.alg.gid != 0 link!(vi, spl) end # X -> R

init_pre_cond_parameters(vi, spl)
init_warm_up_params(vi, spl)

ϵ = find_good_eps(model, vi, spl) # heuristically find optimal ϵ

if spl.alg.gid != 0 invlink!(vi, spl) end # R -> X

init_da_parameters(spl, ϵ)
update_da_params(spl.info[:wum], ϵ)

push!(spl.info[:accept_his], true)

vi
else
# Set parameters
ϵ = spl.info[:ϵ][end]; dprintln(2, "current ϵ: $ϵ")
ϵ = spl.info[:wum][:ϵ][end]; dprintln(2, "current ϵ: $ϵ")

spl.info[:lf_num] = 0 # reset current lf num counter

Expand Down Expand Up @@ -93,9 +93,12 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)
end

if ~(isdefined(Main, :IJulia) && Main.IJulia.inited) # Fix for Jupyter notebook.
haskey(spl.info, :progress) && ProgressMeter.update!(spl.info[:progress],
spl.info[:progress].counter;
showvalues = [(:ϵ, ϵ), (:tree_depth, j)])
stds_str = string(spl.info[:wum][:stds])
stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str
haskey(spl.info, :progress) && ProgressMeter.update!(
spl.info[:progress],
spl.info[:progress].counter; showvalues = [(:ϵ, ϵ), (:tree_depth, j), (:pre_cond, stds_str)]
)
end

if s′ == 1 && rand() < min(1, n′ / n)
Expand All @@ -108,15 +111,12 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)
j = j + 1
end

# Use Dual Averaging to adapt ϵ
adapt_step_size(spl, α / n_α, spl.alg.delta)

push!(spl.info[:accept_his], true)
vi[spl] = θ
setlogp!(vi, logp)

# Update pre-conditioning matrix
update_pre_cond(vi, spl)
# Adapt step-size and pre-cond
adapt(spl.info[:wum], α / n_α, realpart(vi[spl]))

dprintln(3, "R -> X...")
if spl.alg.gid != 0 invlink!(vi, spl); cleandual!(vi) end
Expand Down
3 changes: 2 additions & 1 deletion src/samplers/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ include("support/resample.jl")
end
include("support/hmc_core.jl")
include("support/adapt.jl")
include("support/init.jl")
include("hmcda.jl")
include("nuts.jl")
include("hmc.jl")
Expand All @@ -32,7 +33,7 @@ assume(spl::Void, dist::Distribution, vn::VarName, vi::VarInfo) = begin
if haskey(vi, vn)
r = vi[vn]
else
r = rand(dist)
r = init(dist)
push!(vi, vn, r, dist, 0)
end
acclogp!(vi, logpdf(dist, r, istrans(vi, vn)))
Expand Down
143 changes: 103 additions & 40 deletions src/samplers/support/adapt.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,123 @@
adapt_step_size{T<:Hamiltonian}(spl::Sampler{T}, stats::Float64, δ::Float64) = begin
type WarmUpManager
iter_n :: Int
state :: Int
params :: Dict
end

getindex(wum::WarmUpManager, param) = wum.params[param]

setindex!(wum::WarmUpManager, value, param) = wum.params[param] = value

init_warm_up_params{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin
wum = WarmUpManager(1, 1, Dict())

# Pre-cond
wum[:θ_num] = 0
wum[:θ_mean] = nothing
D = length(vi[spl])
wum[:stds] = ones(D)
wum[:vars] = ones(D)

# DA
wum[:ϵ] = nothing
wum[:μ] = nothing
wum[:ϵ_bar] = 1.0
wum[:H_bar] = 0.0
wum[:m] = 0
wum[:n_adapt] = spl.alg.n_adapt
wum[:δ] = spl.alg.delta

spl.info[:wum] = wum
end

update_da_params(wum::WarmUpManager, ϵ::Float64) = begin
wum[:ϵ] = [ϵ]
wum[:μ] = log(10 * ϵ)
end

adapt_step_size(wum::WarmUpManager, stats::Float64) = begin
dprintln(2, "adapting step size ϵ...")
m = spl.info[:m] += 1
if m <= spl.alg.n_adapt
γ = 0.05; t_0 = 10; κ = 0.75
μ = spl.info[:μ]; ϵ_bar = spl.info[:ϵ_bar]; H_bar = spl.info[:H_bar]
m = wum[:m] += 1
if m <= wum[:n_adapt]
γ = 0.05; t_0 = 10; κ = 0.75; δ = wum[:δ]
μ = wum[:μ]; ϵ_bar = wum[:ϵ_bar]; H_bar = wum[:H_bar]

H_bar = (1 - 1 / (m + t_0)) * H_bar + 1 / (m + t_0) * (δ - stats)
ϵ = exp(μ - sqrt(m) / γ * H_bar)
dprintln(1, " ϵ = $ϵ, stats = $stats")

ϵ_bar = exp(m^(-κ) * log(ϵ) + (1 - m^(-κ)) * log(ϵ_bar))
push!(spl.info[:ϵ], ϵ)
spl.info[:ϵ_bar], spl.info[:H_bar] = ϵ_bar, H_bar
push!(wum[:ϵ], ϵ)
wum[:ϵ_bar], wum[:H_bar] = ϵ_bar, H_bar

if m == spl.alg.n_adapt
if m == wum[:n_adapt]
dprintln(0, " Adapted ϵ = $ϵ, $m HMC iterations is used for adaption.")
end
end
end

init_da_parameters{T<:Hamiltonian}(spl::Sampler{T}, ϵ::Float64) = begin
spl.info[:ϵ] = [ϵ]
spl.info[:μ] = log(10 * ϵ)
spl.info[:ϵ_bar] = 1.0
spl.info[:H_bar] = 0.0
spl.info[:m] = 0
end
update_pre_cond(wum::WarmUpManager, θ_new) = begin

wum[:θ_num] += 1 # θ_new = x_t
t = wum[:θ_num] # t

update_pre_cond{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin
θ_new = realpart(vi[spl]) # x_t
spl.info[:θ_num] += 1
t = spl.info[:θ_num] # t
θ_mean_old = copy(spl.info[:θ_mean]) # x_bar_t-1
spl.info[:θ_mean] = (t - 1) / t * spl.info[:θ_mean] + θ_new / t # x_bar_t
θ_mean_new = spl.info[:θ_mean] # x_bar_t

if t == 2
first_two = [θ_mean_old'; θ_new'] # θ_mean_old here only contains the first θ
spl.info[:θ_vars] = diag(cov(first_two))
elseif t <= 1000
D = length(θ_new)
# D = 2.4^2
spl.info[:θ_vars] = (t - 1) / t * spl.info[:θ_vars] .+ 100 * eps(Float64) +
(2.4^2 / D) / t * (t * θ_mean_old .* θ_mean_old - (t + 1) * θ_mean_new .* θ_mean_new + θ_new .* θ_new)
if t == 1
wum[:θ_mean] = θ_new
else
θ_mean_old = copy(wum[:θ_mean]) # x_bar_t-1
wum[:θ_mean] = (t - 1) / t * wum[:θ_mean] + θ_new / t # x_bar_t
θ_mean_new = wum[:θ_mean] # x_bar_t

if t == 2
first_two = [θ_mean_old'; θ_new'] # θ_mean_old here only contains the first θ
wum[:vars] = diag(cov(first_two))
else#if t <= 1000
D = length(θ_new)
# D = 2.4^2
wum[:vars] = (t - 1) / t * wum[:vars] .+ 100 * eps(Float64) +
(2.4^2 / D) / t * (t * θ_mean_old .* θ_mean_old - (t + 1) * θ_mean_new .* θ_mean_new + θ_new .* θ_new)
end

if t > 100
wum[:stds] = sqrt(wum[:vars])
wum[:stds] = wum[:stds] / min(wum[:stds]...)
end
end
end

update_state(wum::WarmUpManager) = begin
wum.iter_n += 1 # update iteration number

if t > 500
spl.info[:stds] = sqrt(spl.info[:θ_vars])
spl.info[:stds] = spl.info[:stds] / min(spl.info[:stds]...)
# Update state
if wum.state == 1
if wum.iter_n > 100
wum.state = 2
end
elseif wum.state == 2
if wum.iter_n > 900
wum.state = 3
end
elseif wum.state == 3
if wum.iter_n > 1000
wum.state = 4
end
elseif wum.state == 4
# no more change
else
error("[Turing.WarmUpManager] unknown state $(wum.state)")
end
end

init_pre_cond_parameters{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin
spl.info[:θ_mean] = realpart(vi[spl])
spl.info[:θ_num] = 1
D = length(vi[spl])
spl.info[:stds] = ones(D)
spl.info[:θ_vars] = nothing
adapt(wum::WarmUpManager, stats::Float64, θ_new) = begin
update_state(wum)

# Use Dual Averaging to adapt ϵ
if wum.state in [1, 2, 3]
adapt_step_size(wum, stats)
end

# Update pre-conditioning matrix
if wum.state == 2
update_pre_cond(wum, θ_new)
end
end
4 changes: 2 additions & 2 deletions src/samplers/support/hmc_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end

sample_momentum(vi::VarInfo, spl::Sampler) = begin
dprintln(2, "sampling momentum...")
randn(length(getranges(vi, spl))) .* spl.info[:stds]
randn(length(getranges(vi, spl))) .* spl.info[:wum][:stds]
end

# Leapfrog step
Expand Down Expand Up @@ -86,7 +86,7 @@ find_H(p::Vector, model::Function, vi::VarInfo, spl::Sampler) = begin
# This can be a result of link/invlink (where expand! is used)
if getlogp(vi) == 0 vi = runmodel(model, vi, spl) end

p_orig = p ./ spl.info[:stds]
p_orig = p ./ spl.info[:wum][:stds]

H = dot(p_orig, p_orig) / 2 + realpart(-getlogp(vi))
if isnan(H) H = Inf else H end
Expand Down
42 changes: 42 additions & 0 deletions src/samplers/support/init.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Only use customized initialization for transformable distributions
init(dist::TransformDistribution) = inittrans(dist)

# Callbacks for un-transformable distributions
init(dist::Distribution) = rand(dist)

# Uniform rand with range
randuni() = e * rand() # may Euler's number give us good luck

inittrans(dist::UnivariateDistribution) = begin
r = Real(randuni())

r = invlink(dist, r)

r
end

inittrans(dist::MultivariateDistribution) = begin
D = size(dist)[1]

r = Vector{Real}(D)
for d = 1:D
r[d] = randuni()
end

r = invlink(dist, r)

r
end

inittrans(dist::MatrixDistribution) = begin
D = size(dist)

r = Matrix{Real}(D...)
for d1 = 1:D, d2 = 1:D
r[d1,d2] = randuni()
end

r = invlink(dist, r)

r
end