Skip to content

Commit ab424d1

Browse files
authored
Merge pull request #273 from yebai/better-warm-up
Better warm up
2 parents 0597402 + df08935 commit ab424d1

File tree

6 files changed

+131
-67
lines changed

6 files changed

+131
-67
lines changed

src/samplers/gibbs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ function sample(model::Function, alg::Gibbs)
103103

104104
if isa(local_spl.alg, Hamiltonian)
105105
lp = realpart(getlogp(varInfo))
106-
epsilon = local_spl.info[][end]
106+
epsilon = local_spl.info[:wum][:ϵ][end]
107107
lf_num = local_spl.info[:lf_num]
108108
end
109109
else

src/samplers/hmc.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ end
4141
# it now reuses the one of HMCDA
4242
Sampler(alg::HMC) = begin
4343
spl = Sampler(HMCDA(alg.n_iters, 0, 0.0, alg.epsilon * alg.tau, alg.space, alg.gid))
44-
spl.info[:ϵ] = [alg.epsilon]
44+
spl.info[:pre_set_ϵ] = alg.epsilon
4545
spl
4646
end
4747

@@ -58,7 +58,7 @@ Sampler(alg::Hamiltonian) = begin
5858
info[:θ_mean] = nothing
5959
info[:θ_num] = 0
6060
info[:stds] = nothing
61-
info[:θ_vars] = nothing
61+
info[:vars] = nothing
6262

6363
# For caching gradient
6464
info[:grad_cache] = Dict{Vector,Vector}()
@@ -118,7 +118,9 @@ function sample{T<:Hamiltonian}(model::Function, alg::T, chunk_size::Int)
118118
end
119119
println(" #lf / sample = $(spl.info[:total_lf_num] / n);")
120120
println(" #evals / sample = $(spl.info[:total_eval_num] / n);")
121-
println(" pre-cond. diag mat = $(spl.info[:stds]).")
121+
stds_str = string(spl.info[:wum][:stds])
122+
stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str
123+
println(" pre-cond. diag mat = $(stds_str).")
122124

123125
global CHUNKSIZE = default_chunk_size
124126

src/samplers/hmcda.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,23 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
5151
if is_first
5252
if spl.alg.gid != 0 link!(vi, spl) end # X -> R
5353

54-
init_pre_cond_parameters(vi, spl)
54+
init_warm_up_params(vi, spl)
5555

5656
ϵ = spl.alg.delta > 0 ?
5757
find_good_eps(model, vi, spl) : # heuristically find optimal ϵ
58-
spl.info[:ϵ][end]
58+
spl.info[:pre_set_ϵ]
5959

6060
if spl.alg.gid != 0 invlink!(vi, spl) end # R -> X
6161

62-
init_da_parameters(spl, ϵ)
62+
update_da_params(spl.info[:wum], ϵ)
6363

6464
push!(spl.info[:accept_his], true)
6565

6666
vi
6767
else
6868
# Set parameters
6969
λ = spl.alg.lambda
70-
ϵ = spl.info[][end]; dprintln(2, "current ϵ: ")
70+
ϵ = spl.info[:wum][:ϵ][end]; dprintln(2, "current ϵ: ")
7171

7272
spl.info[:lf_num] = 0 # reset current lf num counter
7373

@@ -95,15 +95,14 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
9595
α = min(1, exp(-(H - old_H)))
9696

9797
if ~(isdefined(Main, :IJulia) && Main.IJulia.inited) # Fix for Jupyter notebook.
98+
stds_str = string(spl.info[:wum][:stds])
99+
stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str
98100
haskey(spl.info, :progress) && ProgressMeter.update!(
99101
spl.info[:progress],
100-
spl.info[:progress].counter; showvalues = [(, ϵ), (, α), (:pre_cond, spl.info[:stds])]
102+
spl.info[:progress].counter; showvalues = [(, ϵ), (, α), (:pre_cond, stds_str)]
101103
)
102104
end
103105

104-
# Use Dual Averaging to adapt ϵ
105-
adapt_step_size(spl, α, spl.alg.delta)
106-
107106
dprintln(2, "decide wether to accept...")
108107
if rand() < α # accepted
109108
push!(spl.info[:accept_his], true)
@@ -113,8 +112,8 @@ function step(model, spl::Sampler{HMCDA}, vi::VarInfo, is_first::Bool)
113112
setlogp!(vi, old_logp) # reset logp
114113
end
115114

116-
# Update pre-conditioning matrix
117-
update_pre_cond(vi, spl)
115+
# Adapt step-size and pre-cond
116+
adapt(spl.info[:wum], α, realpart(vi[spl]))
118117

119118
dprintln(3, "R -> X...")
120119
if spl.alg.gid != 0 invlink!(vi, spl); cleandual!(vi) end

src/samplers/nuts.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,20 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)
4646
if is_first
4747
if spl.alg.gid != 0 link!(vi, spl) end # X -> R
4848

49-
init_pre_cond_parameters(vi, spl)
49+
init_warm_up_params(vi, spl)
5050

5151
ϵ = find_good_eps(model, vi, spl) # heuristically find optimal ϵ
5252

5353
if spl.alg.gid != 0 invlink!(vi, spl) end # R -> X
5454

55-
init_da_parameters(spl, ϵ)
55+
update_da_params(spl.info[:wum], ϵ)
5656

5757
push!(spl.info[:accept_his], true)
5858

5959
vi
6060
else
6161
# Set parameters
62-
ϵ = spl.info[][end]; dprintln(2, "current ϵ: ")
62+
ϵ = spl.info[:wum][:ϵ][end]; dprintln(2, "current ϵ: ")
6363

6464
spl.info[:lf_num] = 0 # reset current lf num counter
6565

@@ -93,9 +93,12 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)
9393
end
9494

9595
if ~(isdefined(Main, :IJulia) && Main.IJulia.inited) # Fix for Jupyter notebook.
96-
haskey(spl.info, :progress) && ProgressMeter.update!(spl.info[:progress],
97-
spl.info[:progress].counter;
98-
showvalues = [(, ϵ), (:tree_depth, j)])
96+
stds_str = string(spl.info[:wum][:stds])
97+
stds_str = length(stds_str) >= 32 ? stds_str[1:30]*"..." : stds_str
98+
haskey(spl.info, :progress) && ProgressMeter.update!(
99+
spl.info[:progress],
100+
spl.info[:progress].counter; showvalues = [(, ϵ), (:tree_depth, j), (:pre_cond, stds_str)]
101+
)
99102
end
100103

101104
if s′ == 1 && rand() < min(1, n′ / n)
@@ -108,15 +111,12 @@ function step(model::Function, spl::Sampler{NUTS}, vi::VarInfo, is_first::Bool)
108111
j = j + 1
109112
end
110113

111-
# Use Dual Averaging to adapt ϵ
112-
adapt_step_size(spl, α / n_α, spl.alg.delta)
113-
114114
push!(spl.info[:accept_his], true)
115115
vi[spl] = θ
116116
setlogp!(vi, logp)
117117

118-
# Update pre-conditioning matrix
119-
update_pre_cond(vi, spl)
118+
# Adapt step-size and pre-cond
119+
adapt(spl.info[:wum], α / n_α, realpart(vi[spl]))
120120

121121
dprintln(3, "R -> X...")
122122
if spl.alg.gid != 0 invlink!(vi, spl); cleandual!(vi) end

src/samplers/support/adapt.jl

Lines changed: 103 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,123 @@
1-
adapt_step_size{T<:Hamiltonian}(spl::Sampler{T}, stats::Float64, δ::Float64) = begin
1+
type WarmUpManager
2+
iter_n :: Int
3+
state :: Int
4+
params :: Dict
5+
end
6+
7+
getindex(wum::WarmUpManager, param) = wum.params[param]
8+
9+
setindex!(wum::WarmUpManager, value, param) = wum.params[param] = value
10+
11+
init_warm_up_params{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin
12+
wum = WarmUpManager(1, 1, Dict())
13+
14+
# Pre-cond
15+
wum[:θ_num] = 0
16+
wum[:θ_mean] = nothing
17+
D = length(vi[spl])
18+
wum[:stds] = ones(D)
19+
wum[:vars] = ones(D)
20+
21+
# DA
22+
wum[] = nothing
23+
wum[] = nothing
24+
wum[:ϵ_bar] = 1.0
25+
wum[:H_bar] = 0.0
26+
wum[:m] = 0
27+
wum[:n_adapt] = spl.alg.n_adapt
28+
wum[] = spl.alg.delta
29+
30+
spl.info[:wum] = wum
31+
end
32+
33+
update_da_params(wum::WarmUpManager, ϵ::Float64) = begin
34+
wum[] = [ϵ]
35+
wum[] = log(10 * ϵ)
36+
end
37+
38+
adapt_step_size(wum::WarmUpManager, stats::Float64) = begin
239
dprintln(2, "adapting step size ϵ...")
3-
m = spl.info[:m] += 1
4-
if m <= spl.alg.n_adapt
5-
γ = 0.05; t_0 = 10; κ = 0.75
6-
μ = spl.info[]; ϵ_bar = spl.info[:ϵ_bar]; H_bar = spl.info[:H_bar]
40+
m = wum[:m] += 1
41+
if m <= wum[:n_adapt]
42+
γ = 0.05; t_0 = 10; κ = 0.75; δ = wum[]
43+
μ = wum[]; ϵ_bar = wum[:ϵ_bar]; H_bar = wum[:H_bar]
744

845
H_bar = (1 - 1 / (m + t_0)) * H_bar + 1 / (m + t_0) *- stats)
946
ϵ = exp- sqrt(m) / γ * H_bar)
1047
dprintln(1, " ϵ = , stats = $stats")
1148

1249
ϵ_bar = exp(m^(-κ) * log(ϵ) + (1 - m^(-κ)) * log(ϵ_bar))
13-
push!(spl.info[], ϵ)
14-
spl.info[:ϵ_bar], spl.info[:H_bar] = ϵ_bar, H_bar
50+
push!(wum[], ϵ)
51+
wum[:ϵ_bar], wum[:H_bar] = ϵ_bar, H_bar
1552

16-
if m == spl.alg.n_adapt
53+
if m == wum[:n_adapt]
1754
dprintln(0, " Adapted ϵ = , $m HMC iterations is used for adaption.")
1855
end
1956
end
2057
end
2158

22-
init_da_parameters{T<:Hamiltonian}(spl::Sampler{T}, ϵ::Float64) = begin
23-
spl.info[] = [ϵ]
24-
spl.info[] = log(10 * ϵ)
25-
spl.info[:ϵ_bar] = 1.0
26-
spl.info[:H_bar] = 0.0
27-
spl.info[:m] = 0
28-
end
59+
update_pre_cond(wum::WarmUpManager, θ_new) = begin
60+
61+
wum[:θ_num] += 1 # θ_new = x_t
62+
t = wum[:θ_num] # t
2963

30-
update_pre_cond{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin
31-
θ_new = realpart(vi[spl]) # x_t
32-
spl.info[:θ_num] += 1
33-
t = spl.info[:θ_num] # t
34-
θ_mean_old = copy(spl.info[:θ_mean]) # x_bar_t-1
35-
spl.info[:θ_mean] = (t - 1) / t * spl.info[:θ_mean] + θ_new / t # x_bar_t
36-
θ_mean_new = spl.info[:θ_mean] # x_bar_t
37-
38-
if t == 2
39-
first_two = [θ_mean_old'; θ_new'] # θ_mean_old here only contains the first θ
40-
spl.info[:θ_vars] = diag(cov(first_two))
41-
elseif t <= 1000
42-
D = length(θ_new)
43-
# D = 2.4^2
44-
spl.info[:θ_vars] = (t - 1) / t * spl.info[:θ_vars] .+ 100 * eps(Float64) +
45-
(2.4^2 / D) / t * (t * θ_mean_old .* θ_mean_old - (t + 1) * θ_mean_new .* θ_mean_new + θ_new .* θ_new)
64+
if t == 1
65+
wum[:θ_mean] = θ_new
66+
else
67+
θ_mean_old = copy(wum[:θ_mean]) # x_bar_t-1
68+
wum[:θ_mean] = (t - 1) / t * wum[:θ_mean] + θ_new / t # x_bar_t
69+
θ_mean_new = wum[:θ_mean] # x_bar_t
70+
71+
if t == 2
72+
first_two = [θ_mean_old'; θ_new'] # θ_mean_old here only contains the first θ
73+
wum[:vars] = diag(cov(first_two))
74+
else#if t <= 1000
75+
D = length(θ_new)
76+
# D = 2.4^2
77+
wum[:vars] = (t - 1) / t * wum[:vars] .+ 100 * eps(Float64) +
78+
(2.4^2 / D) / t * (t * θ_mean_old .* θ_mean_old - (t + 1) * θ_mean_new .* θ_mean_new + θ_new .* θ_new)
79+
end
80+
81+
if t > 100
82+
wum[:stds] = sqrt(wum[:vars])
83+
wum[:stds] = wum[:stds] / min(wum[:stds]...)
84+
end
4685
end
86+
end
87+
88+
update_state(wum::WarmUpManager) = begin
89+
wum.iter_n += 1 # update iteration number
4790

48-
if t > 500
49-
spl.info[:stds] = sqrt(spl.info[:θ_vars])
50-
spl.info[:stds] = spl.info[:stds] / min(spl.info[:stds]...)
91+
# Update state
92+
if wum.state == 1
93+
if wum.iter_n > 100
94+
wum.state = 2
95+
end
96+
elseif wum.state == 2
97+
if wum.iter_n > 900
98+
wum.state = 3
99+
end
100+
elseif wum.state == 3
101+
if wum.iter_n > 1000
102+
wum.state = 4
103+
end
104+
elseif wum.state == 4
105+
# no more change
106+
else
107+
error("[Turing.WarmUpManager] unknown state $(wum.state)")
51108
end
52109
end
53110

54-
init_pre_cond_parameters{T<:Hamiltonian}(vi::VarInfo, spl::Sampler{T}) = begin
55-
spl.info[:θ_mean] = realpart(vi[spl])
56-
spl.info[:θ_num] = 1
57-
D = length(vi[spl])
58-
spl.info[:stds] = ones(D)
59-
spl.info[:θ_vars] = nothing
111+
adapt(wum::WarmUpManager, stats::Float64, θ_new) = begin
112+
update_state(wum)
113+
114+
# Use Dual Averaging to adapt ϵ
115+
if wum.state in [1, 2, 3]
116+
adapt_step_size(wum, stats)
117+
end
118+
119+
# Update pre-conditioning matrix
120+
if wum.state == 2
121+
update_pre_cond(wum, θ_new)
122+
end
60123
end

src/samplers/support/hmc_core.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ end
1515

1616
sample_momentum(vi::VarInfo, spl::Sampler) = begin
1717
dprintln(2, "sampling momentum...")
18-
randn(length(getranges(vi, spl))) .* spl.info[:stds]
18+
randn(length(getranges(vi, spl))) .* spl.info[:wum][:stds]
1919
end
2020

2121
# Leapfrog step
@@ -86,7 +86,7 @@ find_H(p::Vector, model::Function, vi::VarInfo, spl::Sampler) = begin
8686
# This can be a result of link/invlink (where expand! is used)
8787
if getlogp(vi) == 0 vi = runmodel(model, vi, spl) end
8888

89-
p_orig = p ./ spl.info[:stds]
89+
p_orig = p ./ spl.info[:wum][:stds]
9090

9191
H = dot(p_orig, p_orig) / 2 + realpart(-getlogp(vi))
9292
if isnan(H) H = Inf else H end

0 commit comments

Comments
 (0)