Skip to content

Commit dd80ad6

Browse files
author
KDr2
committed
Argument parent in constructor of Sampler
1 parent ee7726d commit dd80ad6

File tree

8 files changed

+25
-30
lines changed

8 files changed

+25
-30
lines changed

src/Turing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,13 @@ Turing translates models to chunks that call the modelling functions at specifie
9191
then include that file at the end of this one.
9292
"""
9393
mutable struct Sampler{T} <: AbstractSampler
94-
parent :: AbstractSampler
9594
alg :: T
9695
info :: Dict{Symbol, Any} # sampler infomation
9796
selector :: Selector
97+
parent :: AbstractSampler
9898
end
99-
Sampler(alg, model::Model) = Sampler(alg)
100-
Sampler(alg, info::Dict{Symbol, Any}) = Sampler(SampleFromPrior(), alg, info, Selector())
99+
Sampler(alg, model::Model, parent=SampleFromPrior()) = Sampler(alg, parent)
100+
Sampler(alg, info::Dict{Symbol, Any}, parent=SampleFromPrior()) = Sampler(alg, info, Selector(), parent)
101101

102102
include("utilities/Utilities.jl")
103103
using .Utilities

src/inference/dynamichmc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ function DynamicNUTS{AD}(n_iters::Integer, space...) where AD
3232
DynamicNUTS{AD, eltype(_space)}(n_iters, _space)
3333
end
3434

35-
function Sampler(alg::DynamicNUTS{T}) where T <: Hamiltonian
36-
return Sampler(alg, Dict{Symbol,Any}())
35+
function Sampler(alg::DynamicNUTS{T}, parent=SampleFromPrior()) where T <: Hamiltonian
36+
return Sampler(alg, Dict{Symbol,Any}(), parent)
3737
end
3838

3939
function sample(model::Model,

src/inference/gibbs.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ Gibbs(n_iters::Int, algs...; thin=true) = Gibbs(n_iters, algs, thin)
3333

3434
const GibbsComponent = Union{Hamiltonian,MH,PG}
3535

36-
function Sampler(alg::Gibbs, model::Model)
36+
function Sampler(alg::Gibbs, model::Model, parent=SampleFromPrior())
3737
info = Dict{Symbol, Any}()
38-
spl = Sampler(alg, info)
38+
spl = Sampler(alg, info, parent)
3939

4040
n_samplers = length(alg.algs)
4141
samplers = Array{Sampler}(undef, n_samplers)
@@ -44,8 +44,7 @@ function Sampler(alg::Gibbs, model::Model)
4444
for i in 1:n_samplers
4545
sub_alg = alg.algs[i]
4646
if isa(sub_alg, GibbsComponent)
47-
samplers[i] = Sampler(sub_alg, model)
48-
samplers[i].parent = spl
47+
samplers[i] = Sampler(sub_alg, model, spl)
4948
else
5049
@error("[Gibbs] unsupport base sampling algorithm $alg")
5150
end

src/inference/hmc.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,11 @@ end
7474
DEFAULT_ADAPT_CONF_TYPE = Nothing
7575
STAN_DEFAULT_ADAPT_CONF = nothing
7676

77-
Sampler(alg::Hamiltonian) = Sampler(alg, nothing)
78-
function Sampler(alg::Hamiltonian, adapt_conf::Nothing)
79-
return _sampler(alg::Hamiltonian, adapt_conf)
77+
Sampler(alg::Hamiltonian, parent=SampleFromPrior()) = Sampler(alg, nothing, parent)
78+
function Sampler(alg::Hamiltonian, adapt_conf::Nothing, parent=SampleFromPrior())
79+
return _sampler(alg::Hamiltonian, adapt_conf, parent)
8080
end
81-
function _sampler(alg::Hamiltonian, adapt_conf)
81+
function _sampler(alg::Hamiltonian, adapt_conf, parent=SampleFromPrior())
8282
info=Dict{Symbol, Any}()
8383

8484
# For state infomation
@@ -88,7 +88,7 @@ function _sampler(alg::Hamiltonian, adapt_conf)
8888
# Adapt configuration
8989
info[:adapt_conf] = adapt_conf
9090

91-
Sampler(alg, info)
91+
Sampler(alg, info, parent)
9292
end
9393

9494
function sample(model::Model, alg::Hamiltonian;

src/inference/ipmcmc.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,20 @@ function IPMCMC(n1::Int, n2::Int, n3::Int, n4::Int, space...)
5252
IPMCMC(n1, n2, n3, n4, resample_systematic, _space)
5353
end
5454

55-
function Sampler(alg::IPMCMC)
55+
function Sampler(alg::IPMCMC, parent=SampleFromPrior())
5656
info = Dict{Symbol, Any}()
57-
spl = Sampler(alg, info)
57+
spl = Sampler(alg, info, parent)
5858
# Create SMC and CSMC nodes
5959
samplers = Array{Sampler}(undef, alg.n_nodes)
6060
# Use resampler_threshold=1.0 for SMC since adaptive resampling is invalid in this setting
6161
default_CSMC = CSMC(alg.n_particles, 1, alg.resampler, alg.space)
6262
default_SMC = SMC(alg.n_particles, alg.resampler, 1.0, false, alg.space)
6363

6464
for i in 1:alg.n_csmc_nodes
65-
samplers[i] = Sampler(default_CSMC)
66-
samplers[i].parent = spl
65+
samplers[i] = Sampler(default_CSMC, spl)
6766
end
6867
for i in (alg.n_csmc_nodes+1):alg.n_nodes
69-
samplers[i] = Sampler(default_SMC)
70-
samplers[i].parent = spl
68+
samplers[i] = Sampler(default_SMC, spl)
7169
end
7270

7371
info[:samplers] = samplers

src/inference/mh.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,11 @@ function MH(n_iters::Int, space...)
4848
MH{eltype(set)}(n_iters, proposals, set)
4949
end
5050

51-
function Sampler(alg::MH, model::Model)
51+
function Sampler(alg::MH, model::Model, parent=SampleFromPrior())
5252
alg_str = "MH"
5353

5454
# Sanity check for space
55-
# TODO if (we are going to create a top-level Sampler) && !isempty(alg.space)
56-
if false && !isempty(alg.space)
55+
if parent == SampleFromPrior() && !isempty(alg.space)
5756
@assert issubset(Set(get_pvars(model)), alg.space) "[$alg_str] symbols specified to samplers ($alg.space) doesn't cover the model parameters ($(Set(get_pvars(model))))"
5857
if Set(get_pvars(model)) != alg.space
5958
warn("[$alg_str] extra parameters specified by samplers don't exist in model: $(setdiff(alg.space, Set(get_pvars(model))))")
@@ -65,7 +64,7 @@ function Sampler(alg::MH, model::Model)
6564
info[:prior_prob] = 0.0
6665
info[:violating_support] = false
6766

68-
return Sampler(alg, info)
67+
return Sampler(alg, info, parent)
6968
end
7069

7170
function propose(model, spl::Sampler{<:MH}, vi::VarInfo)

src/inference/pgibbs.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ end
4141

4242
const CSMC = PG # type alias of PG as Conditional SMC
4343

44-
function Sampler(alg::PG)
44+
function Sampler(alg::PG, parent=SampleFromPrior())
4545
info = Dict{Symbol, Any}()
4646
info[:logevidence] = []
47-
Sampler(alg, info)
47+
Sampler(alg, info, parent)
4848
end
4949

5050
step(model, spl::Sampler{<:PG}, vi::VarInfo, _) = step(model, spl, vi)

src/inference/pmmh.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ end
3232

3333
PIMH(n_iters::Int, smc_alg::SMC) = PMMH(n_iters, tuple(smc_alg), Set())
3434

35-
function Sampler(alg::PMMH, model::Model)
35+
function Sampler(alg::PMMH, model::Model, parent=SampleFromPrior())
3636
info = Dict{Symbol, Any}()
37-
spl = Sampler(alg, info)
37+
spl = Sampler(alg, info, parent)
3838

3939
alg_str = "PMMH"
4040
n_samplers = length(alg.algs)
@@ -45,8 +45,7 @@ function Sampler(alg::PMMH, model::Model)
4545
for i in 1:n_samplers
4646
sub_alg = alg.algs[i]
4747
if isa(sub_alg, Union{SMC, MH})
48-
samplers[i] = Sampler(sub_alg, model)
49-
samplers[i].parent = spl
48+
samplers[i] = Sampler(sub_alg, model, spl)
5049
else
5150
error("[$alg_str] unsupport base sampling algorithm $alg")
5251
end

0 commit comments

Comments
 (0)