diff --git a/docs/src/using-turing/interface.md b/docs/src/using-turing/interface.md index 24e999953a..868f4ca8b6 100644 --- a/docs/src/using-turing/interface.md +++ b/docs/src/using-turing/interface.md @@ -100,9 +100,6 @@ Note that we only have to do this because we are not yet integrating the sampler struct DensityModel{F<:Function} <: AbstractModel ℓπ :: F end - -# Default density constructor. -DensityModel(π::Function, data::T) where T = DensityModel{VariateForm, ValueSupport, T}(π, data) ``` ### Transition @@ -118,7 +115,7 @@ struct Transition{T<:Union{Vector{<:Real}, <:Real}, L<:Real} <: AbstractTransiti end # Store the new draw and its log density. -Transition(model::M, θ::T) where {M<:DensityModel, T} = Transition(θ, ℓπ(model, θ)) +Transition(model::DensityModel, θ) = Transition(θ, ℓπ(model, θ)) ``` `Transition` can now store any type of parameter, whether it's a vector of draws from multiple parameters or a single univariate draw. We should also tell the interface what specific subtype of `AbstractTransition` we're going to use, so we can just define a new method on `transition_type`: @@ -198,7 +195,7 @@ q(spl::MetropolisHastings, θ::Vector{<:Real}, θcond::Vector{<:Real}) = logpdf( q(spl::MetropolisHastings, t1::Transition, t2::Transition) = q(spl, t1.θ, t2.θ) # Calculate the density of the model given some parameterization. -ℓπ(model::DensityModel, θ::T) where T = model.ℓπ(θ) +ℓπ(model::DensityModel, θ) = model.ℓπ(θ) ℓπ(model::DensityModel, t::Transition) = t.lp # Define the other step function. Returns a Transition containing @@ -235,13 +232,13 @@ The last piece in our puzzle is a `bundle_samples` function, which accepts a `Ve # A basic chains constructor that works with the Transition struct we defined. function bundle_samples( rng::AbstractRNG, - ℓ::DensityModel, + ℓ::AbstractModel, s::MetropolisHastings, N::Integer, - ts::Vector{T}; + ts::Vector{<:AbstractTransition}; param_names=missing, kwargs... -) where {ModelType<:AbstractModel, T<:AbstractTransition} +) # Turn all the transitions into a vector-of-vectors. vals = copy(reduce(hcat,[vcat(t.θ, t.lp) for t in ts])')