diff --git a/src/samplers/gibbs.jl b/src/samplers/gibbs.jl
index 60a4534f86..fd793cc6b9 100644
--- a/src/samplers/gibbs.jl
+++ b/src/samplers/gibbs.jl
@@ -103,7 +103,7 @@ function sample(model::Function, alg::Gibbs)
         if isa(local_spl.alg, Hamiltonian)
           lp = realpart(getlogp(varInfo))
-          epsilon = local_spl.info[:ϵ][end]
+          epsilon = local_spl.info[:wum][:ϵ][end]
           lf_num = local_spl.info[:lf_num]
         end
       else
diff --git a/src/samplers/hmc.jl b/src/samplers/hmc.jl
index ce77d3e058..6e75f4f606 100644
--- a/src/samplers/hmc.jl
+++ b/src/samplers/hmc.jl
@@ -41,7 +41,7 @@ end
 # it now reuses the one of HMCDA
 Sampler(alg::HMC) = begin
   spl = Sampler(HMCDA(alg.n_iters, 0, 0.0, alg.epsilon * alg.tau, alg.space, alg.gid))
-  spl.info[:ϵ] = [alg.epsilon]
+  spl.info[:pre_set_ϵ] = alg.epsilon
   spl
 end
@@ -58,7 +58,7 @@ Sampler(alg::Hamiltonian) = begin
   info[:θ_mean] = nothing
   info[:θ_num] = 0
   info[:stds] = nothing
-  info[:θ_vars] = nothing
+  info[:vars] = nothing
 
   # For caching gradient
   info[:grad_cache] = Dict{Vector,Vector}()
@@ -118,7 +118,9 @@ function sample{T<:Hamiltonian}(model::Function, alg::T, chunk_size::Int)
   end
   println("  #lf / sample        = $(spl.info[:total_lf_num] / n);")
   println("  #evals / sample     = $(spl.info[:total_eval_num] / n);")
-  println("  pre-cond. diag mat  = $(spl.info[:stds]).")
+  stds_str = string(spl.info[:wum][:stds])
+  stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str
+  println("  pre-cond. diag mat  = $(stds_str).")
 
   global CHUNKSIZE = default_chunk_size
diff --git a/src/samplers/hmcda.jl b/src/samplers/hmcda.jl
index 0d8b69d7a2..5bfb042c7e 100644
--- a/src/samplers/hmcda.jl
+++ b/src/samplers/hmcda.jl
@@ -51,15 +51,15 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
   if is_first
     if spl.alg.gid != 0 link!(vi, spl) end      # X -> R
 
-    init_pre_cond_parameters(vi, spl)
+    init_warm_up_params(vi, spl)
 
     ϵ = spl.alg.delta > 0 ?
       find_good_eps(model, vi, spl) :           # heuristically find optimal ϵ
-      spl.info[:ϵ][end]
+      spl.info[:pre_set_ϵ]
 
     if spl.alg.gid != 0 invlink!(vi, spl) end   # R -> X
 
-    init_da_parameters(spl, ϵ)
+    update_da_params(spl.info[:wum], ϵ)
 
     push!(spl.info[:accept_his], true)
@@ -67,7 +67,7 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
   else
     # Set parameters
     λ = spl.alg.lambda
-    ϵ = spl.info[:ϵ][end]; dprintln(2, "current ϵ: $ϵ")
+    ϵ = spl.info[:wum][:ϵ][end]; dprintln(2, "current ϵ: $ϵ")
 
     spl.info[:lf_num] = 0   # reset current lf num counter
@@ -95,15 +95,14 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
     α = min(1, exp(-(H - old_H)))
 
     if ~(isdefined(Main, :IJulia) && Main.IJulia.inited) # Fix for Jupyter notebook.
+      stds_str = string(spl.info[:wum][:stds])
+      stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str
       haskey(spl.info, :progress) && ProgressMeter.update!(
         spl.info[:progress],
-        spl.info[:progress].counter; showvalues = [(:ϵ, ϵ), (:α, α), (:pre_cond, spl.info[:stds])]
+        spl.info[:progress].counter; showvalues = [(:ϵ, ϵ), (:α, α), (:pre_cond, stds_str)]
       )
     end
 
-    # Use Dual Averaging to adapt ϵ
-    adapt_step_size(spl, α, spl.alg.delta)
-
     dprintln(2, "decide whether to accept...")
     if rand() < α      # accepted
       push!(spl.info[:accept_his], true)
@@ -113,8 +112,8 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
       setlogp!(vi, old_logp)  # reset logp
     end
 
-    # Update pre-conditioning matrix
-    update_pre_cond(vi, spl)
+    # Adapt step-size and pre-cond
+    adapt(spl.info[:wum], α, realpart(vi[spl]))
 
     dprintln(3, "R -> X...")
     if spl.alg.gid != 0 invlink!(vi, spl); cleandual!(vi) end
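Note on the repeated truncation: the same three-line `stds_str` pattern now appears in hmc.jl, hmcda.jl, and nuts.jl. A sketch of a helper that could factor it out — `short_str` is a hypothetical name, not part of this diff:

```julia
# Hypothetical helper, equivalent to the inlined pattern above:
# truncate the printed pre-conditioner to 30 characters plus "...".
short_str(x) = begin
  s = string(x)
  length(s) >= 32 ? s[1:30] * "..." : s
end

short_str(ones(3))    # "[1.0,1.0,1.0]" — short enough, returned as-is
short_str(ones(100))  # first 30 characters followed by "..."
```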
diff --git a/src/samplers/nuts.jl b/src/samplers/nuts.jl
index f3c0124d11..aeff2c3424 100644
--- a/src/samplers/nuts.jl
+++ b/src/samplers/nuts.jl
@@ -46,20 +46,20 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)
   if is_first
     if spl.alg.gid != 0 link!(vi, spl) end      # X -> R
 
-    init_pre_cond_parameters(vi, spl)
+    init_warm_up_params(vi, spl)
 
     ϵ = find_good_eps(model, vi, spl)           # heuristically find optimal ϵ
 
     if spl.alg.gid != 0 invlink!(vi, spl) end   # R -> X
 
-    init_da_parameters(spl, ϵ)
+    update_da_params(spl.info[:wum], ϵ)
 
     push!(spl.info[:accept_his], true)
 
     vi
   else
     # Set parameters
-    ϵ = spl.info[:ϵ][end]; dprintln(2, "current ϵ: $ϵ")
+    ϵ = spl.info[:wum][:ϵ][end]; dprintln(2, "current ϵ: $ϵ")
 
     spl.info[:lf_num] = 0   # reset current lf num counter
@@ -93,9 +93,12 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)
       end
 
       if ~(isdefined(Main, :IJulia) && Main.IJulia.inited) # Fix for Jupyter notebook.
-        haskey(spl.info, :progress) && ProgressMeter.update!(spl.info[:progress],
-                                                             spl.info[:progress].counter;
-                                                             showvalues = [(:ϵ, ϵ), (:tree_depth, j)])
+        stds_str = string(spl.info[:wum][:stds])
+        stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str
+        haskey(spl.info, :progress) && ProgressMeter.update!(
+          spl.info[:progress],
+          spl.info[:progress].counter; showvalues = [(:ϵ, ϵ), (:tree_depth, j), (:pre_cond, stds_str)]
+        )
       end
 
       if s′ == 1 && rand() < min(1, n′ / n)
@@ -108,15 +111,12 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)
       j = j + 1
     end
 
-    # Use Dual Averaging to adapt ϵ
-    adapt_step_size(spl, α / n_α, spl.alg.delta)
-
     push!(spl.info[:accept_his], true)
     vi[spl] = θ
     setlogp!(vi, logp)
 
-    # Update pre-conditioning matrix
-    update_pre_cond(vi, spl)
+    # Adapt step-size and pre-cond
+    adapt(spl.info[:wum], α / n_α, realpart(vi[spl]))
 
     dprintln(3, "R -> X...")
     if spl.alg.gid != 0 invlink!(vi, spl); cleandual!(vi) end
diff --git a/src/samplers/sampler.jl b/src/samplers/sampler.jl
index b96c8f88a3..09fdef07e3 100644
--- a/src/samplers/sampler.jl
+++ b/src/samplers/sampler.jl
@@ -7,6 +7,7 @@ include("support/resample.jl")
 end
 include("support/hmc_core.jl")
 include("support/adapt.jl")
+include("support/init.jl")
 include("hmcda.jl")
 include("nuts.jl")
 include("hmc.jl")
@@ -32,7 +33,7 @@ assume(spl::Void, dist::Distribution, vn::VarName, vi::VarInfo) = begin
   if haskey(vi, vn)
     r = vi[vn]
   else
-    r = rand(dist)
+    r = init(dist)
     push!(vi, vn, r, dist, 0)
   end
   acclogp!(vi, logpdf(dist, r, istrans(vi, vn)))
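For orientation before the next file: all warm-up state now lives behind a single object with `Dict`-style access. A minimal sketch of the pattern, assuming the `WarmUpManager` definitions from adapt.jl further down:

```julia
wum = WarmUpManager(1, 1, Dict())  # iter_n = 1, state = 1, empty params
wum[:ϵ] = [0.1]                    # setindex! forwards to wum.params
wum[:stds] = ones(2)               # read by sample_momentum / find_H
wum[:ϵ][end]                       # 0.1 — the "current step size" lookup used in the samplers
```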
diff --git a/src/samplers/support/adapt.jl b/src/samplers/support/adapt.jl
index a6565427a7..3b604057db 100644
--- a/src/samplers/support/adapt.jl
+++ b/src/samplers/support/adapt.jl
@@ -1,60 +1,123 @@
-adapt_step_size{T<:Hamiltonian}(spl::Sampler{T}, stats::Float64, δ::Float64) = begin
+type WarmUpManager
+  iter_n :: Int
+  state  :: Int
+  params :: Dict
+end
+
+getindex(wum::WarmUpManager, param) = wum.params[param]
+
+setindex!(wum::WarmUpManager, value, param) = wum.params[param] = value
+
+init_warm_up_params{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin
+  wum = WarmUpManager(1, 1, Dict())
+
+  # Pre-cond
+  wum[:θ_num] = 0
+  wum[:θ_mean] = nothing
+  D = length(vi[spl])
+  wum[:stds] = ones(D)
+  wum[:vars] = ones(D)
+
+  # DA
+  wum[:ϵ] = nothing
+  wum[:μ] = nothing
+  wum[:ϵ_bar] = 1.0
+  wum[:H_bar] = 0.0
+  wum[:m] = 0
+  wum[:n_adapt] = spl.alg.n_adapt
+  wum[:δ] = spl.alg.delta
+
+  spl.info[:wum] = wum
+end
+
+update_da_params(wum::WarmUpManager, ϵ::Float64) = begin
+  wum[:ϵ] = [ϵ]
+  wum[:μ] = log(10 * ϵ)
+end
+
+adapt_step_size(wum::WarmUpManager, stats::Float64) = begin
   dprintln(2, "adapting step size ϵ...")
-  m = spl.info[:m] += 1
-  if m <= spl.alg.n_adapt
-    γ = 0.05; t_0 = 10; κ = 0.75
-    μ = spl.info[:μ]; ϵ_bar = spl.info[:ϵ_bar]; H_bar = spl.info[:H_bar]
+  m = wum[:m] += 1
+  if m <= wum[:n_adapt]
+    γ = 0.05; t_0 = 10; κ = 0.75; δ = wum[:δ]
+    μ = wum[:μ]; ϵ_bar = wum[:ϵ_bar]; H_bar = wum[:H_bar]
 
     H_bar = (1 - 1 / (m + t_0)) * H_bar + 1 / (m + t_0) * (δ - stats)
     ϵ = exp(μ - sqrt(m) / γ * H_bar)
     dprintln(1, " ϵ = $ϵ, stats = $stats")
 
     ϵ_bar = exp(m^(-κ) * log(ϵ) + (1 - m^(-κ)) * log(ϵ_bar))
-    push!(spl.info[:ϵ], ϵ)
-    spl.info[:ϵ_bar], spl.info[:H_bar] = ϵ_bar, H_bar
+    push!(wum[:ϵ], ϵ)
+    wum[:ϵ_bar], wum[:H_bar] = ϵ_bar, H_bar
 
-    if m == spl.alg.n_adapt
+    if m == wum[:n_adapt]
       dprintln(0, " Adapted ϵ = $ϵ; $m HMC iterations were used for adaptation.")
     end
   end
 end
 
-init_da_parameters{T<:Hamiltonian}(spl::Sampler{T}, ϵ::Float64) = begin
-  spl.info[:ϵ] = [ϵ]
-  spl.info[:μ] = log(10 * ϵ)
-  spl.info[:ϵ_bar] = 1.0
-  spl.info[:H_bar] = 0.0
-  spl.info[:m] = 0
-end
+update_pre_cond(wum::WarmUpManager, θ_new) = begin
+
+  wum[:θ_num] += 1    # θ_new = x_t
+  t = wum[:θ_num]     # t
-update_pre_cond{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin
-  θ_new = realpart(vi[spl])                # x_t
-  spl.info[:θ_num] += 1
-  t = spl.info[:θ_num]                     # t
-  θ_mean_old = copy(spl.info[:θ_mean])     # x_bar_t-1
-  spl.info[:θ_mean] = (t - 1) / t * spl.info[:θ_mean] + θ_new / t  # x_bar_t
-  θ_mean_new = spl.info[:θ_mean]           # x_bar_t
-
-  if t == 2
-    first_two = [θ_mean_old'; θ_new']      # θ_mean_old here only contains the first θ
-    spl.info[:θ_vars] = diag(cov(first_two))
-  elseif t <= 1000
-    D = length(θ_new)
-    # D = 2.4^2
-    spl.info[:θ_vars] = (t - 1) / t * spl.info[:θ_vars] .+ 100 * eps(Float64) +
-      (2.4^2 / D) / t * (t * θ_mean_old .* θ_mean_old - (t + 1) * θ_mean_new .* θ_mean_new + θ_new .* θ_new)
+  if t == 1
+    wum[:θ_mean] = θ_new
+  else
+    θ_mean_old = copy(wum[:θ_mean])        # x_bar_t-1
+    wum[:θ_mean] = (t - 1) / t * wum[:θ_mean] + θ_new / t  # x_bar_t
+    θ_mean_new = wum[:θ_mean]              # x_bar_t
+
+    if t == 2
+      first_two = [θ_mean_old'; θ_new']    # θ_mean_old here only contains the first θ
+      wum[:vars] = diag(cov(first_two))
+    else#if t <= 1000
+      D = length(θ_new)
+      # D = 2.4^2
+      wum[:vars] = (t - 1) / t * wum[:vars] .+ 100 * eps(Float64) +
+        (2.4^2 / D) / t * (t * θ_mean_old .* θ_mean_old - (t + 1) * θ_mean_new .* θ_mean_new + θ_new .* θ_new)
+    end
+
+    if t > 100
+      wum[:stds] = sqrt(wum[:vars])
+      wum[:stds] = wum[:stds] / min(wum[:stds]...)
+    end
   end
+end
+
+update_state(wum::WarmUpManager) = begin
+  wum.iter_n += 1   # update iteration number
 
-  if t > 500
-    spl.info[:stds] = sqrt(spl.info[:θ_vars])
-    spl.info[:stds] = spl.info[:stds] / min(spl.info[:stds]...)
+  # Update state
+  if wum.state == 1
+    if wum.iter_n > 100
+      wum.state = 2
+    end
+  elseif wum.state == 2
+    if wum.iter_n > 900
+      wum.state = 3
+    end
+  elseif wum.state == 3
+    if wum.iter_n > 1000
+      wum.state = 4
+    end
+  elseif wum.state == 4
+    # no more change
+  else
+    error("[Turing.WarmUpManager] unknown state $(wum.state)")
   end
 end
 
-init_pre_cond_parameters{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin
-  spl.info[:θ_mean] = realpart(vi[spl])
-  spl.info[:θ_num] = 1
-  D = length(vi[spl])
-  spl.info[:stds] = ones(D)
-  spl.info[:θ_vars] = nothing
+adapt(wum::WarmUpManager, stats::Float64, θ_new) = begin
+  update_state(wum)
+
+  # Use Dual Averaging to adapt ϵ
+  if wum.state in [1, 2, 3]
+    adapt_step_size(wum, stats)
+  end
+
+  # Update pre-conditioning matrix
+  if wum.state == 2
+    update_pre_cond(wum, θ_new)
+  end
 end
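To make the hard-coded schedule in `update_state` concrete: state 1 (up to iteration 100) runs dual averaging only, state 2 (iterations 101–900) runs dual averaging plus pre-conditioning updates, state 3 (901–1000) runs dual averaging only again, and state 4 freezes all adaptation. A quick sketch to verify, assuming the adapt.jl definitions above are loaded:

```julia
wum = WarmUpManager(1, 1, Dict())
first_seen = Dict{Int,Int}()   # state => iter_n when first observed
for i in 1:1100
  update_state(wum)
  haskey(first_seen, wum.state) || (first_seen[wum.state] = wum.iter_n)
end
first_seen  # expect 2 => 101, 3 => 901, 4 => 1001 (and 1 => 2 from the first call)
```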
diff --git a/src/samplers/support/hmc_core.jl b/src/samplers/support/hmc_core.jl
index 1c99e7fb9f..c4bcff17b9 100644
--- a/src/samplers/support/hmc_core.jl
+++ b/src/samplers/support/hmc_core.jl
@@ -15,7 +15,7 @@ end
 sample_momentum(vi::VarInfo, spl::Sampler) = begin
   dprintln(2, "sampling momentum...")
-  randn(length(getranges(vi, spl))) .* spl.info[:stds]
+  randn(length(getranges(vi, spl))) .* spl.info[:wum][:stds]
 end
 
 # Leapfrog step
@@ -86,7 +86,7 @@ find_H(p::Vector, model::Function, vi::VarInfo, spl::Sampler) = begin
   # This can be a result of link/invlink (where expand! is used)
   if getlogp(vi) == 0 vi = runmodel(model, vi, spl) end
-  p_orig = p ./ spl.info[:stds]
+  p_orig = p ./ spl.info[:wum][:stds]
 
   H = dot(p_orig, p_orig) / 2 + realpart(-getlogp(vi))
   if isnan(H) H = Inf else H end
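A small self-contained sketch of the pre-conditioning convention shared by `sample_momentum` and `find_H`: momentum is drawn with per-coordinate scales `stds` and rescaled back inside the kinetic term, so each coordinate contributes a unit-variance term regardless of scale:

```julia
stds = [0.5, 2.0]                  # stand-in for spl.info[:wum][:stds]
p = randn(length(stds)) .* stds    # as in sample_momentum
K = dot(p ./ stds, p ./ stds) / 2  # kinetic part of find_H
# 2K follows a χ² distribution with 2 degrees of freedom whatever stds is,
# so the Hamiltonian stays consistent with how p was sampled.
```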
diff --git a/src/samplers/support/init.jl b/src/samplers/support/init.jl
new file mode 100644
index 0000000000..b380b1b66a
--- /dev/null
+++ b/src/samplers/support/init.jl
@@ -0,0 +1,42 @@
+# Only use customized initialization for transformable distributions
+init(dist::TransformDistribution) = inittrans(dist)
+
+# Callbacks for un-transformable distributions
+init(dist::Distribution) = rand(dist)
+
+# Uniform rand on (0, e)
+randuni() = e * rand()   # may Euler's number give us good luck
+
+inittrans(dist::UnivariateDistribution) = begin
+  r = Real(randuni())
+
+  r = invlink(dist, r)
+
+  r
+end
+
+inittrans(dist::MultivariateDistribution) = begin
+  D = size(dist)[1]
+
+  r = Vector{Real}(D)
+  for d = 1:D
+    r[d] = randuni()
+  end
+
+  r = invlink(dist, r)
+
+  r
+end
+
+inittrans(dist::MatrixDistribution) = begin
+  D = size(dist)
+
+  r = Matrix{Real}(D...)
+  for d1 = 1:D[1], d2 = 1:D[2]
+    r[d1,d2] = randuni()
+  end
+
+  r = invlink(dist, r)
+
+  r
+end
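Finally, a hypothetical REPL session for the new `init`, assuming Turing's `TransformDistribution` union covers `Beta` (logit transform) but not discrete distributions such as `Poisson`; exact values depend on the RNG:

```julia
using Distributions

init(Beta(2, 2))    # transformable: invlink maps a draw from (0, e) back into the support (0, 1)
init(Poisson(3.0))  # no transform defined: falls back to rand(dist)
```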