From 1bd7b41eec39a1711de23f579bc6a9a48ff2e7d2 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 29 May 2017 17:00:55 +0100 Subject: [PATCH 1/7] Better initialization #219 --- src/samplers/sampler.jl | 3 ++- src/samplers/support/init.jl | 42 ++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 src/samplers/support/init.jl diff --git a/src/samplers/sampler.jl b/src/samplers/sampler.jl index b96c8f88a3..09fdef07e3 100644 --- a/src/samplers/sampler.jl +++ b/src/samplers/sampler.jl @@ -7,6 +7,7 @@ include("support/resample.jl") end include("support/hmc_core.jl") include("support/adapt.jl") +include("support/init.jl") include("hmcda.jl") include("nuts.jl") include("hmc.jl") @@ -32,7 +33,7 @@ assume(spl::Void, dist::Distribution, vn::VarName, vi::VarInfo) = begin if haskey(vi, vn) r = vi[vn] else - r = rand(dist) + r = init(dist) push!(vi, vn, r, dist, 0) end acclogp!(vi, logpdf(dist, r, istrans(vi, vn))) diff --git a/src/samplers/support/init.jl b/src/samplers/support/init.jl new file mode 100644 index 0000000000..b380b1b66a --- /dev/null +++ b/src/samplers/support/init.jl @@ -0,0 +1,42 @@ +# Only use customized initialization for transformable distributions +init(dist::TransformDistribution) = inittrans(dist) + +# Callbacks for un-transformable distributions +init(dist::Distribution) = rand(dist) + +# Uniform rand with range +randuni() = e * rand() # may Euler's number give us good luck + +inittrans(dist::UnivariateDistribution) = begin + r = Real(randuni()) + + r = invlink(dist, r) + + r +end + +inittrans(dist::MultivariateDistribution) = begin + D = size(dist)[1] + + r = Vector{Real}(D) + for d = 1:D + r[d] = randuni() + end + + r = invlink(dist, r) + + r +end + +inittrans(dist::MatrixDistribution) = begin + D = size(dist) + + r = Matrix{Real}(D...) + for d1 = 1:D, d2 = 1:D + r[d1,d2] = randuni() + end + + r = invlink(dist, r) + + r +end From c7982d97c7190e72795a1e0de531e5fbea5d9c1a Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Mon, 29 May 2017 22:57:31 +0100 Subject: [PATCH 2/7] Unify warm-up init --- src/samplers/hmcda.jl | 4 +-- src/samplers/nuts.jl | 4 +-- src/samplers/support/adapt.jl | 46 ++++++++++++++++++++++++----------- 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/src/samplers/hmcda.jl b/src/samplers/hmcda.jl index 0d8b69d7a2..e856c23f70 100644 --- a/src/samplers/hmcda.jl +++ b/src/samplers/hmcda.jl @@ -51,7 +51,7 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool) if is_first if spl.alg.gid != 0 link!(vi, spl) end # X -> R - init_pre_cond_parameters(vi, spl) + init_warm_up_params(vi, spl) ϵ = spl.alg.delta > 0 ? find_good_eps(model, vi, spl) : # heuristically find optimal ϵ @@ -59,7 +59,7 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool) if spl.alg.gid != 0 invlink!(vi, spl) end # R -> X - init_da_parameters(spl, ϵ) + update_da_params(spl, ϵ) push!(spl.info[:accept_his], true) diff --git a/src/samplers/nuts.jl b/src/samplers/nuts.jl index f3c0124d11..2ac336a234 100644 --- a/src/samplers/nuts.jl +++ b/src/samplers/nuts.jl @@ -46,13 +46,13 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool) if is_first if spl.alg.gid != 0 link!(vi, spl) end # X -> R - init_pre_cond_parameters(vi, spl) + init_warm_up_params(vi, spl) ϵ = find_good_eps(model, vi, spl) # heuristically find optimal ϵ if spl.alg.gid != 0 invlink!(vi, spl) end # R -> X - init_da_parameters(spl, ϵ) + update_da_params(spl, ϵ) push!(spl.info[:accept_his], true) diff --git a/src/samplers/support/adapt.jl b/src/samplers/support/adapt.jl index a6565427a7..9e24eaa9e8 100644 --- a/src/samplers/support/adapt.jl +++ b/src/samplers/support/adapt.jl @@ -1,3 +1,25 @@ +init_warm_up_params{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin + # Pre-cond + spl.info[:θ_mean] = realpart(vi[spl]) + spl.info[:θ_num] = 1 + D = length(vi[spl]) + spl.info[:stds] = ones(D) + spl.info[:θ_vars] = ones(D) + # DA + if ~haskey(spl.info, :ϵ) + spl.info[:ϵ] = nothing + end + spl.info[:μ] = nothing + spl.info[:ϵ_bar] = 1.0 + spl.info[:H_bar] = 0.0 + spl.info[:m] = 0 +end + +update_da_params{T<:Hamiltonian}(spl::Sampler{T}, ϵ::Float64) = begin + spl.info[:ϵ] = [ϵ] + spl.info[:μ] = log(10 * ϵ) +end + adapt_step_size{T<:Hamiltonian}(spl::Sampler{T}, stats::Float64, δ::Float64) = begin dprintln(2, "adapting step size ϵ...") m = spl.info[:m] += 1 @@ -19,14 +41,6 @@ adapt_step_size{T<:Hamiltonian}(spl::Sampler{T}, stats::Float64, δ::Float64) = end end -init_da_parameters{T<:Hamiltonian}(spl::Sampler{T}, ϵ::Float64) = begin - spl.info[:ϵ] = [ϵ] - spl.info[:μ] = log(10 * ϵ) - spl.info[:ϵ_bar] = 1.0 - spl.info[:H_bar] = 0.0 - spl.info[:m] = 0 -end - update_pre_cond{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin θ_new = realpart(vi[spl]) # x_t spl.info[:θ_num] += 1 @@ -51,10 +65,14 @@ update_pre_cond{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin end end -init_pre_cond_parameters{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin - spl.info[:θ_mean] = realpart(vi[spl]) - spl.info[:θ_num] = 1 - D = length(vi[spl]) - spl.info[:stds] = ones(D) - spl.info[:θ_vars] = nothing + + +type WarmUpManager + state :: Int + curr_iter :: Int + info :: Dict +end + +update_state(wum::WarmUpManager) = begin + end From 7a6e30f043ed607cf85d8893bfda5dbf61629c50 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 30 May 2017 13:11:54 +0100 Subject: [PATCH 3/7] Add a warm-up manager --- src/samplers/gibbs.jl | 2 +- src/samplers/hmc.jl | 4 +- src/samplers/hmcda.jl | 10 ++-- src/samplers/nuts.jl | 8 +-- src/samplers/support/adapt.jl | 94 +++++++++++++++++--------------- src/samplers/support/hmc_core.jl | 4 +- 6 files changed, 64 insertions(+), 58 deletions(-) diff --git a/src/samplers/gibbs.jl b/src/samplers/gibbs.jl index 60a4534f86..fd793cc6b9 100644 --- a/src/samplers/gibbs.jl +++ b/src/samplers/gibbs.jl @@ -103,7 +103,7 @@ function sample(model::Function, alg::Gibbs) if isa(local_spl.alg, Hamiltonian) lp = realpart(getlogp(varInfo)) - epsilon = local_spl.info[:ϵ][end] + epsilon = local_spl.info[:wum][:ϵ][end] lf_num = local_spl.info[:lf_num] end else diff --git a/src/samplers/hmc.jl b/src/samplers/hmc.jl index ce77d3e058..8270bbe1af 100644 --- a/src/samplers/hmc.jl +++ b/src/samplers/hmc.jl @@ -41,7 +41,7 @@ end # it now reuses the one of HMCDA Sampler(alg::HMC) = begin spl = Sampler(HMCDA(alg.n_iters, 0, 0.0, alg.epsilon * alg.tau, alg.space, alg.gid)) - spl.info[:ϵ] = [alg.epsilon] + spl.info[:pre_set_ϵ] = alg.epsilon spl end @@ -58,7 +58,7 @@ Sampler(alg::Hamiltonian) = begin info[:θ_mean] = nothing info[:θ_num] = 0 info[:stds] = nothing - info[:θ_vars] = nothing + info[:vars] = nothing # For caching gradient info[:grad_cache] = Dict{Vector,Vector}() diff --git a/src/samplers/hmcda.jl b/src/samplers/hmcda.jl index e856c23f70..9d31de4a06 100644 --- a/src/samplers/hmcda.jl +++ b/src/samplers/hmcda.jl @@ -55,11 +55,11 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool) ϵ = spl.alg.delta > 0 ? find_good_eps(model, vi, spl) : # heuristically find optimal ϵ - spl.info[:ϵ][end] + spl.info[:pre_set_ϵ] if spl.alg.gid != 0 invlink!(vi, spl) end # R -> X - update_da_params(spl, ϵ) + update_da_params(spl.info[:wum], ϵ) push!(spl.info[:accept_his], true) @@ -67,7 +67,7 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool) else # Set parameters λ = spl.alg.lambda - ϵ = spl.info[:ϵ][end]; dprintln(2, "current ϵ: $ϵ") + ϵ = spl.info[:wum][:ϵ][end]; dprintln(2, "current ϵ: $ϵ") spl.info[:lf_num] = 0 # reset current lf num counter @@ -102,7 +102,7 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool) end # Use Dual Averaging to adapt ϵ - adapt_step_size(spl, α, spl.alg.delta) + adapt_step_size(spl.info[:wum], α, spl.alg.delta) dprintln(2, "decide wether to accept...") if rand() < α # accepted @@ -114,7 +114,7 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool) end # Update pre-conditioning matrix - update_pre_cond(vi, spl) + update_pre_cond(spl.info[:wum], realpart(vi[spl])) dprintln(3, "R -> X...") if spl.alg.gid != 0 invlink!(vi, spl); cleandual!(vi) end diff --git a/src/samplers/nuts.jl b/src/samplers/nuts.jl index 2ac336a234..738e8f0de8 100644 --- a/src/samplers/nuts.jl +++ b/src/samplers/nuts.jl @@ -52,14 +52,14 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool) if spl.alg.gid != 0 invlink!(vi, spl) end # R -> X - update_da_params(spl, ϵ) + update_da_params(spl.info[:wum], ϵ) push!(spl.info[:accept_his], true) vi else # Set parameters - ϵ = spl.info[:ϵ][end]; dprintln(2, "current ϵ: $ϵ") + ϵ = spl.info[:wum][:ϵ][end]; dprintln(2, "current ϵ: $ϵ") spl.info[:lf_num] = 0 # reset current lf num counter @@ -109,14 +109,14 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool) end # Use Dual Averaging to adapt ϵ - adapt_step_size(spl, α / n_α, spl.alg.delta) + adapt_step_size(spl.info[:wum], α / n_α, spl.alg.delta) push!(spl.info[:accept_his], true) vi[spl] = θ setlogp!(vi, logp) # Update pre-conditioning matrix - update_pre_cond(vi, spl) + update_pre_cond(spl.info[:wum], realpart(vi[spl])) dprintln(3, "R -> X...") if spl.alg.gid != 0 invlink!(vi, spl); cleandual!(vi) end diff --git a/src/samplers/support/adapt.jl b/src/samplers/support/adapt.jl index 9e24eaa9e8..c1db027a43 100644 --- a/src/samplers/support/adapt.jl +++ b/src/samplers/support/adapt.jl @@ -1,78 +1,84 @@ +type WarmUpManager + state :: Int + curr_iter :: Int + params :: Dict +end + +getindex(wum::WarmUpManager, param) = wum.params[param] + +setindex!(wum::WarmUpManager, value, param) = wum.params[param] = value + +update_state(wum::WarmUpManager) = begin + +end + init_warm_up_params{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin + wum = WarmUpManager(1, 1, Dict()) + # Pre-cond - spl.info[:θ_mean] = realpart(vi[spl]) - spl.info[:θ_num] = 1 + wum[:θ_num] = 1 + wum[:θ_mean] = realpart(vi[spl]) D = length(vi[spl]) - spl.info[:stds] = ones(D) - spl.info[:θ_vars] = ones(D) + wum[:stds] = ones(D) + wum[:vars] = ones(D) + # DA - if ~haskey(spl.info, :ϵ) - spl.info[:ϵ] = nothing - end - spl.info[:μ] = nothing - spl.info[:ϵ_bar] = 1.0 - spl.info[:H_bar] = 0.0 - spl.info[:m] = 0 + wum[:ϵ] = nothing + wum[:μ] = nothing + wum[:ϵ_bar] = 1.0 + wum[:H_bar] = 0.0 + wum[:m] = 0 + wum[:n_adapt] = spl.alg.n_adapt + + spl.info[:wum] = wum end -update_da_params{T<:Hamiltonian}(spl::Sampler{T}, ϵ::Float64) = begin - spl.info[:ϵ] = [ϵ] - spl.info[:μ] = log(10 * ϵ) +update_da_params(wum::WarmUpManager, ϵ::Float64) = begin + wum[:ϵ] = [ϵ] + wum[:μ] = log(10 * ϵ) end -adapt_step_size{T<:Hamiltonian}(spl::Sampler{T}, stats::Float64, δ::Float64) = begin +adapt_step_size(wum::WarmUpManager, stats::Float64, δ::Float64) = begin dprintln(2, "adapting step size ϵ...") - m = spl.info[:m] += 1 - if m <= spl.alg.n_adapt + m = wum[:m] += 1 + if m <= wum[:n_adapt] γ = 0.05; t_0 = 10; κ = 0.75 - μ = spl.info[:μ]; ϵ_bar = spl.info[:ϵ_bar]; H_bar = spl.info[:H_bar] + μ = wum[:μ]; ϵ_bar = wum[:ϵ_bar]; H_bar = wum[:H_bar] H_bar = (1 - 1 / (m + t_0)) * H_bar + 1 / (m + t_0) * (δ - stats) ϵ = exp(μ - sqrt(m) / γ * H_bar) dprintln(1, " ϵ = $ϵ, stats = $stats") ϵ_bar = exp(m^(-κ) * log(ϵ) + (1 - m^(-κ)) * log(ϵ_bar)) - push!(spl.info[:ϵ], ϵ) - spl.info[:ϵ_bar], spl.info[:H_bar] = ϵ_bar, H_bar + push!(wum[:ϵ], ϵ) + wum[:ϵ_bar], wum[:H_bar] = ϵ_bar, H_bar - if m == spl.alg.n_adapt + if m == wum[:n_adapt] dprintln(0, " Adapted ϵ = $ϵ, $m HMC iterations is used for adaption.") end end end -update_pre_cond{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin - θ_new = realpart(vi[spl]) # x_t - spl.info[:θ_num] += 1 - t = spl.info[:θ_num] # t - θ_mean_old = copy(spl.info[:θ_mean]) # x_bar_t-1 - spl.info[:θ_mean] = (t - 1) / t * spl.info[:θ_mean] + θ_new / t # x_bar_t - θ_mean_new = spl.info[:θ_mean] # x_bar_t +update_pre_cond(wum::WarmUpManager, θ_new) = begin + + wum[:θ_num] += 1 # θ_new = x_t + t = wum[:θ_num] # t + θ_mean_old = copy(wum[:θ_mean]) # x_bar_t-1 + wum[:θ_mean] = (t - 1) / t * wum[:θ_mean] + θ_new / t # x_bar_t + θ_mean_new = wum[:θ_mean] # x_bar_t if t == 2 first_two = [θ_mean_old'; θ_new'] # θ_mean_old here only contains the first θ - spl.info[:θ_vars] = diag(cov(first_two)) + wum[:vars] = diag(cov(first_two)) elseif t <= 1000 D = length(θ_new) # D = 2.4^2 - spl.info[:θ_vars] = (t - 1) / t * spl.info[:θ_vars] .+ 100 * eps(Float64) + + wum[:vars] = (t - 1) / t * wum[:vars] .+ 100 * eps(Float64) + (2.4^2 / D) / t * (t * θ_mean_old .* θ_mean_old - (t + 1) * θ_mean_new .* θ_mean_new + θ_new .* θ_new) end if t > 500 - spl.info[:stds] = sqrt(spl.info[:θ_vars]) - spl.info[:stds] = spl.info[:stds] / min(spl.info[:stds]...) + wum[:stds] = sqrt(wum[:vars]) + wum[:stds] = wum[:stds] / min(wum[:stds]...) end end - - - -type WarmUpManager - state :: Int - curr_iter :: Int - info :: Dict -end - -update_state(wum::WarmUpManager) = begin - -end diff --git a/src/samplers/support/hmc_core.jl b/src/samplers/support/hmc_core.jl index 1c99e7fb9f..c4bcff17b9 100644 --- a/src/samplers/support/hmc_core.jl +++ b/src/samplers/support/hmc_core.jl @@ -15,7 +15,7 @@ end sample_momentum(vi::VarInfo, spl::Sampler) = begin dprintln(2, "sampling momentum...") - randn(length(getranges(vi, spl))) .* spl.info[:stds] + randn(length(getranges(vi, spl))) .* spl.info[:wum][:stds] end # Leapfrog step @@ -86,7 +86,7 @@ find_H(p::Vector, model::Function, vi::VarInfo, spl::Sampler) = begin # This can be a result of link/invlink (where expand! is used) if getlogp(vi) == 0 vi = runmodel(model, vi, spl) end - p_orig = p ./ spl.info[:stds] + p_orig = p ./ spl.info[:wum][:stds] H = dot(p_orig, p_orig) / 2 + realpart(-getlogp(vi)) if isnan(H) H = Inf else H end From 9237ede1e065c0f4e7c94ff1ad089e38510e8d3c Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 30 May 2017 13:15:50 +0100 Subject: [PATCH 4/7] Move adapt functions to the same place --- src/samplers/hmcda.jl | 6 +++--- src/samplers/nuts.jl | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/samplers/hmcda.jl b/src/samplers/hmcda.jl index 9d31de4a06..e0369e581f 100644 --- a/src/samplers/hmcda.jl +++ b/src/samplers/hmcda.jl @@ -101,9 +101,6 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool) ) end - # Use Dual Averaging to adapt ϵ - adapt_step_size(spl.info[:wum], α, spl.alg.delta) - dprintln(2, "decide wether to accept...") if rand() < α # accepted push!(spl.info[:accept_his], true) @@ -113,6 +110,9 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool) setlogp!(vi, old_logp) # reset logp end + # Use Dual Averaging to adapt ϵ + adapt_step_size(spl.info[:wum], α, spl.alg.delta) + # Update pre-conditioning matrix update_pre_cond(spl.info[:wum], realpart(vi[spl])) diff --git a/src/samplers/nuts.jl b/src/samplers/nuts.jl index 738e8f0de8..68ab5bd1fa 100644 --- a/src/samplers/nuts.jl +++ b/src/samplers/nuts.jl @@ -108,13 +108,13 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool) j = j + 1 end - # Use Dual Averaging to adapt ϵ - adapt_step_size(spl.info[:wum], α / n_α, spl.alg.delta) - push!(spl.info[:accept_his], true) vi[spl] = θ setlogp!(vi, logp) + # Use Dual Averaging to adapt ϵ + adapt_step_size(spl.info[:wum], α / n_α, spl.alg.delta) + # Update pre-conditioning matrix update_pre_cond(spl.info[:wum], realpart(vi[spl])) From a59537e3bc31d62ff566ef806b9365fa3588150c Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 30 May 2017 13:52:42 +0100 Subject: [PATCH 5/7] Finish wum state update --- src/samplers/hmc.jl | 4 +++- src/samplers/hmcda.jl | 7 ++---- src/samplers/nuts.jl | 7 ++---- src/samplers/support/adapt.jl | 44 +++++++++++++++++++++++++++++++---- 4 files changed, 46 insertions(+), 16 deletions(-) diff --git a/src/samplers/hmc.jl b/src/samplers/hmc.jl index 8270bbe1af..96a137f999 100644 --- a/src/samplers/hmc.jl +++ b/src/samplers/hmc.jl @@ -118,7 +118,9 @@ function sample{T<:Hamiltonian}(model::Function, alg::T, chunk_size::Int) end println(" #lf / sample = $(spl.info[:total_lf_num] / n);") println(" #evals / sample = $(spl.info[:total_eval_num] / n);") - println(" pre-cond. diag mat = $(spl.info[:stds]).") + stds_str = string(spl.info[:wum][:stds]) + stds_str = length(stds_str) >= 16 ? stds_str[1:14]*"..." : stds_str + println(" pre-cond. diag mat = $stds_str.") global CHUNKSIZE = default_chunk_size diff --git a/src/samplers/hmcda.jl b/src/samplers/hmcda.jl index e0369e581f..016f693f3e 100644 --- a/src/samplers/hmcda.jl +++ b/src/samplers/hmcda.jl @@ -110,11 +110,8 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool) setlogp!(vi, old_logp) # reset logp end - # Use Dual Averaging to adapt ϵ - adapt_step_size(spl.info[:wum], α, spl.alg.delta) - - # Update pre-conditioning matrix - update_pre_cond(spl.info[:wum], realpart(vi[spl])) + # Adapt step-size and pre-cond + adapt(spl.info[:wum], α, realpart(vi[spl])) dprintln(3, "R -> X...") if spl.alg.gid != 0 invlink!(vi, spl); cleandual!(vi) end diff --git a/src/samplers/nuts.jl b/src/samplers/nuts.jl index 68ab5bd1fa..c3f0611d76 100644 --- a/src/samplers/nuts.jl +++ b/src/samplers/nuts.jl @@ -112,11 +112,8 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool) vi[spl] = θ setlogp!(vi, logp) - # Use Dual Averaging to adapt ϵ - adapt_step_size(spl.info[:wum], α / n_α, spl.alg.delta) - - # Update pre-conditioning matrix - update_pre_cond(spl.info[:wum], realpart(vi[spl])) + # Adapt step-size and pre-cond + adapt(spl.info[:wum], α / n_α, realpart(vi[spl])) dprintln(3, "R -> X...") if spl.alg.gid != 0 invlink!(vi, spl); cleandual!(vi) end diff --git a/src/samplers/support/adapt.jl b/src/samplers/support/adapt.jl index c1db027a43..e53d744375 100644 --- a/src/samplers/support/adapt.jl +++ b/src/samplers/support/adapt.jl @@ -1,7 +1,7 @@ type WarmUpManager - state :: Int - curr_iter :: Int - params :: Dict + iter_n :: Int + state :: Int + params :: Dict end getindex(wum::WarmUpManager, param) = wum.params[param] @@ -9,7 +9,26 @@ getindex(wum::WarmUpManager, param) = wum.params[param] setindex!(wum::WarmUpManager, value, param) = wum.params[param] = value update_state(wum::WarmUpManager) = begin + wum.iter_n += 1 # update iteration number + # Update state + if wum.state == 1 + if wum.iter_n >= 100 + wum.state = 2 + end + elseif wum.state == 2 + if wum.iter_n >= 900 + wum.state = 3 + end + elseif wum.state == 3 + if wum.iter_n >= 1000 + wum.state = 4 + end + elseif wum.state == 4 + # no more change + else + error("[Turing.WarmUpManager] unknown state $(wum.state)") + end end init_warm_up_params{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin @@ -29,6 +48,7 @@ init_warm_up_params{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin wum[:H_bar] = 0.0 wum[:m] = 0 wum[:n_adapt] = spl.alg.n_adapt + wum[:δ] = spl.alg.delta spl.info[:wum] = wum end @@ -38,11 +58,11 @@ update_da_params(wum::WarmUpManager, ϵ::Float64) = begin wum[:μ] = log(10 * ϵ) end -adapt_step_size(wum::WarmUpManager, stats::Float64, δ::Float64) = begin +adapt_step_size(wum::WarmUpManager, stats::Float64) = begin dprintln(2, "adapting step size ϵ...") m = wum[:m] += 1 if m <= wum[:n_adapt] - γ = 0.05; t_0 = 10; κ = 0.75 + γ = 0.05; t_0 = 10; κ = 0.75; δ = wum[:δ] μ = wum[:μ]; ϵ_bar = wum[:ϵ_bar]; H_bar = wum[:H_bar] H_bar = (1 - 1 / (m + t_0)) * H_bar + 1 / (m + t_0) * (δ - stats) @@ -82,3 +102,17 @@ update_pre_cond(wum::WarmUpManager, θ_new) = begin wum[:stds] = wum[:stds] / min(wum[:stds]...) end end + +adapt(wum::WarmUpManager, stats::Float64, θ_new) = begin + update_state(wum) + + # Use Dual Averaging to adapt ϵ + # if wum.state in [1, 2, 3] + adapt_step_size(wum, stats) + # end + + # Update pre-conditioning matrix + # if wum.state == 2 + update_pre_cond(wum, θ_new) + # end +end From b1dd86e8ec68c5c393e4a37fecfa3ea625ab86c6 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 30 May 2017 13:58:32 +0100 Subject: [PATCH 6/7] Make init of pre-cond into loop function --- src/samplers/hmc.jl | 2 +- src/samplers/support/adapt.jl | 33 +++++++++++++++++++-------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/samplers/hmc.jl b/src/samplers/hmc.jl index 96a137f999..5309b0b9f4 100644 --- a/src/samplers/hmc.jl +++ b/src/samplers/hmc.jl @@ -119,7 +119,7 @@ function sample{T<:Hamiltonian}(model::Function, alg::T, chunk_size::Int) println(" #lf / sample = $(spl.info[:total_lf_num] / n);") println(" #evals / sample = $(spl.info[:total_eval_num] / n);") stds_str = string(spl.info[:wum][:stds]) - stds_str = length(stds_str) >= 16 ? stds_str[1:14]*"..." : stds_str + stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str println(" pre-cond. diag mat = $stds_str.") global CHUNKSIZE = default_chunk_size diff --git a/src/samplers/support/adapt.jl b/src/samplers/support/adapt.jl index e53d744375..83d4c36c7b 100644 --- a/src/samplers/support/adapt.jl +++ b/src/samplers/support/adapt.jl @@ -35,8 +35,8 @@ init_warm_up_params{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin wum = WarmUpManager(1, 1, Dict()) # Pre-cond - wum[:θ_num] = 1 - wum[:θ_mean] = realpart(vi[spl]) + wum[:θ_num] = 0 + wum[:θ_mean] = nothing D = length(vi[spl]) wum[:stds] = ones(D) wum[:vars] = ones(D) @@ -83,18 +83,23 @@ update_pre_cond(wum::WarmUpManager, θ_new) = begin wum[:θ_num] += 1 # θ_new = x_t t = wum[:θ_num] # t - θ_mean_old = copy(wum[:θ_mean]) # x_bar_t-1 - wum[:θ_mean] = (t - 1) / t * wum[:θ_mean] + θ_new / t # x_bar_t - θ_mean_new = wum[:θ_mean] # x_bar_t - - if t == 2 - first_two = [θ_mean_old'; θ_new'] # θ_mean_old here only contains the first θ - wum[:vars] = diag(cov(first_two)) - elseif t <= 1000 - D = length(θ_new) - # D = 2.4^2 - wum[:vars] = (t - 1) / t * wum[:vars] .+ 100 * eps(Float64) + - (2.4^2 / D) / t * (t * θ_mean_old .* θ_mean_old - (t + 1) * θ_mean_new .* θ_mean_new + θ_new .* θ_new) + + if t == 1 + wum[:θ_mean] = θ_new + else + θ_mean_old = copy(wum[:θ_mean]) # x_bar_t-1 + wum[:θ_mean] = (t - 1) / t * wum[:θ_mean] + θ_new / t # x_bar_t + θ_mean_new = wum[:θ_mean] # x_bar_t + + if t == 2 + first_two = [θ_mean_old'; θ_new'] # θ_mean_old here only contains the first θ + wum[:vars] = diag(cov(first_two)) + elseif t <= 1000 + D = length(θ_new) + # D = 2.4^2 + wum[:vars] = (t - 1) / t * wum[:vars] .+ 100 * eps(Float64) + + (2.4^2 / D) / t * (t * θ_mean_old .* θ_mean_old - (t + 1) * θ_mean_new .* θ_mean_new + θ_new .* θ_new) + end end if t > 500 From df08935cead591190dd1168e2c4ccf4439aad6aa Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Tue, 30 May 2017 14:15:48 +0100 Subject: [PATCH 7/7] Initial support of windowed warm-up #265 --- src/samplers/hmc.jl | 2 +- src/samplers/hmcda.jl | 4 ++- src/samplers/nuts.jl | 9 +++-- src/samplers/support/adapt.jl | 62 +++++++++++++++++------------------ 4 files changed, 41 insertions(+), 36 deletions(-) diff --git a/src/samplers/hmc.jl b/src/samplers/hmc.jl index 5309b0b9f4..6e75f4f606 100644 --- a/src/samplers/hmc.jl +++ b/src/samplers/hmc.jl @@ -120,7 +120,7 @@ function sample{T<:Hamiltonian}(model::Function, alg::T, chunk_size::Int) println(" #evals / sample = $(spl.info[:total_eval_num] / n);") stds_str = string(spl.info[:wum][:stds]) stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str - println(" pre-cond. diag mat = $stds_str.") + println(" pre-cond. diag mat = $(stds_str).") global CHUNKSIZE = default_chunk_size diff --git a/src/samplers/hmcda.jl b/src/samplers/hmcda.jl index 016f693f3e..5bfb042c7e 100644 --- a/src/samplers/hmcda.jl +++ b/src/samplers/hmcda.jl @@ -95,9 +95,11 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool) α = min(1, exp(-(H - old_H))) if ~(isdefined(Main, :IJulia) && Main.IJulia.inited) # Fix for Jupyter notebook. + stds_str = string(spl.info[:wum][:stds]) + stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str haskey(spl.info, :progress) && ProgressMeter.update!( spl.info[:progress], - spl.info[:progress].counter; showvalues = [(:ϵ, ϵ), (:α, α), (:pre_cond, spl.info[:stds])] + spl.info[:progress].counter; showvalues = [(:ϵ, ϵ), (:α, α), (:pre_cond, stds_str)] ) end diff --git a/src/samplers/nuts.jl b/src/samplers/nuts.jl index c3f0611d76..aeff2c3424 100644 --- a/src/samplers/nuts.jl +++ b/src/samplers/nuts.jl @@ -93,9 +93,12 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool) end if ~(isdefined(Main, :IJulia) && Main.IJulia.inited) # Fix for Jupyter notebook. - haskey(spl.info, :progress) && ProgressMeter.update!(spl.info[:progress], - spl.info[:progress].counter; - showvalues = [(:ϵ, ϵ), (:tree_depth, j)]) + stds_str = string(spl.info[:wum][:stds]) + stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str + haskey(spl.info, :progress) && ProgressMeter.update!( + spl.info[:progress], + spl.info[:progress].counter; showvalues = [(:ϵ, ϵ), (:tree_depth, j), (:pre_cond, stds_str)] + ) end if s′ == 1 && rand() < min(1, n′ / n) diff --git a/src/samplers/support/adapt.jl b/src/samplers/support/adapt.jl index 83d4c36c7b..3b604057db 100644 --- a/src/samplers/support/adapt.jl +++ b/src/samplers/support/adapt.jl @@ -8,29 +8,6 @@ getindex(wum::WarmUpManager, param) = wum.params[param] setindex!(wum::WarmUpManager, value, param) = wum.params[param] = value -update_state(wum::WarmUpManager) = begin - wum.iter_n += 1 # update iteration number - - # Update state - if wum.state == 1 - if wum.iter_n >= 100 - wum.state = 2 - end - elseif wum.state == 2 - if wum.iter_n >= 900 - wum.state = 3 - end - elseif wum.state == 3 - if wum.iter_n >= 1000 - wum.state = 4 - end - elseif wum.state == 4 - # no more change - else - error("[Turing.WarmUpManager] unknown state $(wum.state)") - end -end - init_warm_up_params{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin wum = WarmUpManager(1, 1, Dict()) @@ -94,17 +71,40 @@ update_pre_cond(wum::WarmUpManager, θ_new) = begin if t == 2 first_two = [θ_mean_old'; θ_new'] # θ_mean_old here only contains the first θ wum[:vars] = diag(cov(first_two)) - elseif t <= 1000 + else#if t <= 1000 D = length(θ_new) # D = 2.4^2 wum[:vars] = (t - 1) / t * wum[:vars] .+ 100 * eps(Float64) + (2.4^2 / D) / t * (t * θ_mean_old .* θ_mean_old - (t + 1) * θ_mean_new .* θ_mean_new + θ_new .* θ_new) end + + if t > 100 + wum[:stds] = sqrt(wum[:vars]) + wum[:stds] = wum[:stds] / min(wum[:stds]...) + end end +end + +update_state(wum::WarmUpManager) = begin + wum.iter_n += 1 # update iteration number - if t > 500 - wum[:stds] = sqrt(wum[:vars]) - wum[:stds] = wum[:stds] / min(wum[:stds]...) + # Update state + if wum.state == 1 + if wum.iter_n > 100 + wum.state = 2 + end + elseif wum.state == 2 + if wum.iter_n > 900 + wum.state = 3 + end + elseif wum.state == 3 + if wum.iter_n > 1000 + wum.state = 4 + end + elseif wum.state == 4 + # no more change + else + error("[Turing.WarmUpManager] unknown state $(wum.state)") end end @@ -112,12 +112,12 @@ adapt(wum::WarmUpManager, stats::Float64, θ_new) = begin update_state(wum) # Use Dual Averaging to adapt ϵ - # if wum.state in [1, 2, 3] + if wum.state in [1, 2, 3] adapt_step_size(wum, stats) - # end + end # Update pre-conditioning matrix - # if wum.state == 2 + if wum.state == 2 update_pre_cond(wum, θ_new) - # end + end end