[RFC] Sampler states #31

@devmotion

I think the current approach of dealing with states of samplers, e.g., in Turing, is flawed. Currently, one defines an initial state when initializing the sampler. This is problematic because the concrete types of the states are often not known during initialization, which in turn requires setting up dummy states and states with abstractly typed fields, which can hurt performance.
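To illustrate the typing issue, here is a minimal sketch with hypothetical types (StatefulSampler and ExplicitState are not part of Turing or AbstractMCMC). A sampler that has to hold its state from construction ends up with an abstractly typed field and a dummy value. A state that is only created once sampling starts can instead be concretely typed.

```julia
# Hypothetical sampler that must store its state from construction on.
# The state type is unknown at that point, so the field is typed `Any`
# and filled with a dummy value.
mutable struct StatefulSampler
    stepsize::Float64
    state::Any
end

# Hypothetical explicit state, created only once sampling starts; its
# concrete type is captured by the type parameter `T`.
struct ExplicitState{T}
    value::T
end

sampler = StatefulSampler(0.1, nothing)   # dummy initial state
sampler.state = [0.0, 0.0]                # real state is set later

explicit = ExplicitState([0.0, 0.0])

# The field type of `sampler.state` is `Any` (not concrete), whereas
# `explicit.value` has the concrete type `Vector{Float64}`.
println(isconcretetype(fieldtype(StatefulSampler, :state)))    # false
println(isconcretetype(fieldtype(typeof(explicit), :value)))   # true
```

Accesses to an `Any`-typed field cannot be inferred by the compiler, which is where the potential performance cost comes from.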

With the current API of AbstractMCMC it is possible to avoid this issue by making the sampler state part of the transition generated by step! and then implementing transitions_init and transitions_save! such that only the part of the transition without the sampler state is actually saved (as done for the Gibbs sampler in Turing and proposed for the DynamicNUTS sampler in TuringLang/Turing.jl#1186). However, since only the samples and statistics are stored, it is then impossible to resume sampling because the state of the sampler is no longer known.
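A rough sketch of this workaround, assuming hypothetical names (TransitionWithState, savable) rather than the actual Turing or AbstractMCMC code: the state travels inside each transition, but only the sample and its statistic are kept when saving.

```julia
# Hypothetical transition type carrying the sampler state next to the sample.
struct TransitionWithState{S}
    θ::Vector{Float64}   # the sample
    lp::Float64          # a statistic, e.g. the log density
    state::S             # sampler state needed to compute the next step
end

# When saving, only the sample and the statistic are kept; the state is dropped.
savable(t::TransitionWithState) = (θ = t.θ, lp = t.lp)

# Three toy transitions whose states record a step size and the iteration.
transitions = [TransitionWithState([0.1 * i], -0.5 * i, (stepsize = 0.2, iter = i)) for i in 1:3]
saved = [savable(t) for t in transitions]

# The saved output contains no sampler state, so sampling cannot be resumed from it.
println(propertynames(saved[end]))   # (:θ, :lp)
```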

So as mentioned in TuringLang/Turing.jl#1186, one could switch to a setup with (more or less) stateless samplers and explicit states, something along the lines of:

function mcmcsample(
    rng::Random.AbstractRNG,
    model::AbstractModel,
    sampler::AbstractSampler,
    N::Integer;
    progress = true,
    progressname = "Sampling",
    callback = (args...; kwargs...) -> nothing,
    chain_type::Type=Any,
    kwargs...
)
    # Check the number of requested samples.
    N > 0 || error("the number of samples must be ≥ 1")

    @ifwithprogresslogger progress name=progressname begin
        # Obtain the initial state.
        state = initialstate(rng, model, sampler, N; kwargs...)

        # Run callback.
        callback(rng, model, sampler, N, 1, state; kwargs...)

        # Save the initial sample.
        samples = initsamples(state, model, sampler, N; kwargs...)
        savesample!(samples, 1, state, model, sampler, N; kwargs...)

        # Update the progress bar.
        progress && ProgressLogging.@logprogress 1/N

        # Step through the sampler.
        for i in 2:N
            # Obtain the updated state.
            state = step(rng, model, sampler, N, state; kwargs...)

            # Run callback.
            callback(rng, model, sampler, N, i, state; kwargs...)

            # Save the sample.
            savesample!(samples, i, state, model, sampler, N; kwargs...)

            # Update the progress bar.
            progress && ProgressLogging.@logprogress i/N
        end
    end

    # Compute the resulting MCMC chain.
    chain = samples2chain(rng, model, sampler, N, samples, state, chain_type; kwargs...)

    return chain, state
end

function samples2chain(
    ::Random.AbstractRNG,
    ::AbstractModel,
    ::AbstractSampler,
    ::Integer,
    samples,
    state,
    ::Type{Any};
    kwargs...
)
    return samples
end
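To make the proposal more concrete, here is a minimal sketch of how a sampler could plug into it, assuming a toy random-walk Metropolis sampler targeting a standard normal. Everything lives in a hypothetical ProposalSketch module and only mirrors the names used above (initialstate, step, initsamples, savesample!); progress logging, callbacks, keyword arguments, and samples2chain are left out.

```julia
using Random

module ProposalSketch

using Random

struct NormalModel end                  # target: standard normal

struct RWMH                             # stateless sampler: settings only
    stepsize::Float64
end

struct RWMHState                        # everything needed to continue sampling
    x::Float64
    logp::Float64
end

logdensity(::NormalModel, x) = -x^2 / 2

# Hypothetical initial state: start at zero.
initialstate(rng, model, sampler, N) = RWMHState(0.0, logdensity(model, 0.0))

# One Metropolis step: propose, then accept or reject on the log scale.
function step(rng, model, sampler::RWMH, N, state::RWMHState)
    xnew = state.x + sampler.stepsize * randn(rng)
    lpnew = logdensity(model, xnew)
    return log(rand(rng)) < lpnew - state.logp ? RWMHState(xnew, lpnew) : state
end

# Only the position is stored as a sample; the full state is returned separately.
initsamples(state::RWMHState, model, sampler, N) = Vector{Float64}(undef, N)
savesample!(samples, i, state::RWMHState, model, sampler, N) = (samples[i] = state.x)

function mcmcsample(rng, model, sampler, N::Integer)
    N > 0 || error("the number of samples must be ≥ 1")
    state = initialstate(rng, model, sampler, N)
    samples = initsamples(state, model, sampler, N)
    savesample!(samples, 1, state, model, sampler, N)
    for i in 2:N
        state = step(rng, model, sampler, N, state)
        savesample!(samples, i, state, model, sampler, N)
    end
    return samples, state
end

end # module

rng = MersenneTwister(42)
samples, state = ProposalSketch.mcmcsample(
    rng, ProposalSketch.NormalModel(), ProposalSketch.RWMH(1.0), 1_000,
)
```

Since the sampler only stores its settings, the same RWMH instance can be reused across runs, and the returned state carries everything needed for further steps.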

I renamed transitions to samples and used state instead of transition to make it clearer that each individual transition should be thought of as the current state of the sampler, which provides all the information needed to resume sampling.
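A small sketch of why returning the state matters, assuming a hypothetical AR(1) update as a stand-in for a real MCMC kernel (AR1Sampler, AR1State, advance, and continue_sampling are illustrative names only). Sampling in two pieces, where the second piece starts from the state returned by the first, reproduces a single longer run when the same RNG stream is used.

```julia
using Random

# Hypothetical toy chain: the sampler holds only settings, the state holds
# the current position.
struct AR1Sampler
    ρ::Float64
end

struct AR1State
    x::Float64
end

advance(rng, s::AR1Sampler, st::AR1State) = AR1State(s.ρ * st.x + randn(rng))

# Take `n` steps starting from `state`; return the samples and the final state.
function continue_sampling(rng, sampler, state, n)
    samples = Float64[]
    for _ in 1:n
        state = advance(rng, sampler, state)
        push!(samples, state.x)
    end
    return samples, state
end

sampler = AR1Sampler(0.5)
init_state = AR1State(0.0)

# One run of 10 steps ...
full, _ = continue_sampling(MersenneTwister(1), sampler, init_state, 10)

# ... equals 4 steps followed by 6 more steps resumed from the returned state.
rng = MersenneTwister(1)
first_part, st = continue_sampling(rng, sampler, init_state, 4)
second_part, _ = continue_sampling(rng, sampler, st, 6)

println(vcat(first_part, second_part) == full)   # true
```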

Maybe one could even introduce a type SamplerWithState that would make it easier to bundle a state with the corresponding sampler and return this after sampling:

struct SamplerWithState{S<:AbstractSampler,T}
    sampler::S
    state::T
end
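A possible use of such a bundle, sketched under the assumption of a hypothetical RandomWalk sampler, a hypothetical step_bundle function, and a local stand-in for AbstractSampler (the model is omitted for brevity): each step takes and returns a SamplerWithState, so the sampler and state no longer have to be passed separately.

```julia
using Random

# Stand-in for AbstractMCMC.AbstractSampler so the example is self-contained.
abstract type AbstractSampler end

# Hypothetical stateless sampler: it stores only its settings.
struct RandomWalk <: AbstractSampler
    stepsize::Float64
end

# The bundle from the proposal: a sampler together with its current state.
struct SamplerWithState{S<:AbstractSampler,T}
    sampler::S
    state::T
end

# Hypothetical step function taking the bundle instead of separate sampler
# and state arguments; it returns a new bundle with the updated state.
function step_bundle(rng::Random.AbstractRNG, sws::SamplerWithState{RandomWalk})
    x = sws.state + sws.sampler.stepsize * randn(rng)
    return SamplerWithState(sws.sampler, x)
end

rng = MersenneTwister(7)
sws = SamplerWithState(RandomWalk(0.5), 0.0)
for _ in 1:100
    global sws = step_bundle(rng, sws)
end
```

Since the struct is immutable, each step returns a new bundle rather than mutating the sampler.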

That could possibly also reduce the number of arguments in functions such as step or samples2chain.

As a side note, IMO the callback signature could be simplified to callback(state, i) or even just callback(state), since all other arguments are known when the callback is defined and can therefore be used by the callback if it is defined as a closure over them.
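A sketch of the simplified callback, assuming hypothetical ToyModel and ToySampler types and a toy loop in place of mcmcsample: the callback only receives (state, i), while the model and sampler come from the closure built by a hypothetical make_callback helper.

```julia
struct ToyModel
    name::String
end

struct ToySampler
    stepsize::Float64
end

# Build a callback with the simplified signature `callback(state, i)`; the model,
# the sampler, and the output vector are captured by the closure.
function make_callback(model::ToyModel, sampler::ToySampler, lines::Vector{String})
    return (state, i) -> push!(lines, "iteration $i, model $(model.name), stepsize $(sampler.stepsize): x = $state")
end

# Toy loop standing in for the sampling loop; it only passes state and index.
function toy_run(callback, sampler::ToySampler, N::Integer)
    state = 0.0
    for i in 1:N
        state += sampler.stepsize
        callback(state, i)
    end
    return state
end

model = ToyModel("normal")
sampler = ToySampler(0.5)
lines = String[]
callback = make_callback(model, sampler, lines)
toy_run(callback, sampler, 3)
foreach(println, lines)
```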
