diff --git a/Project.toml b/Project.toml index 7b793cf3..477ccc9d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "2.5.0" +version = "3.0.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/README.md b/README.md index b1dfc7de..56b8093d 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ are: - `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging - `chain_type` (default: `Any`): determines the type of the returned chain - `callback` (default: `nothing`): if `callback !== nothing`, then - `callback(rng, model, sampler, sample, iteration)` is called after every sampling step, + `callback(rng, model, sampler, sample, state, iteration; kwargs...)` is called after every sampling step, where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration - `discard_initial` (default: `0`): number of initial samples that are discarded - `thinning` (default: `1`): factor by which to thin samples. diff --git a/src/sample.jl b/src/sample.jl index 45505b31..563d4652 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -109,7 +109,7 @@ function mcmcsample( end # Run callback. - callback === nothing || callback(rng, model, sampler, sample, 1) + callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) @@ -140,7 +140,7 @@ function mcmcsample( sample, state = step(rng, model, sampler, state; kwargs...) # Run callback. - callback === nothing || callback(rng, model, sampler, sample, i) + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = save!!(samples, sample, i, model, sampler, N; kwargs...) diff --git a/test/sample.jl b/test/sample.jl index 8b2ea06c..00f7ccae 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -289,4 +289,14 @@ @test mean(x.b for x in chain) ≈ 0 atol=0.1 @test var(x.b for x in chain) ≈ 1 atol=0.15 end + + @testset "Testing callbacks" begin + function count_iterations(rng, model, sampler, sample, state, i; iter_array, kwargs...) + iter_array[i] = i + end + N = 100 + it_array = zeros(N) + sample(MyModel(), MySampler(), N; callback=count_iterations, iter_array=it_array) + @test it_array == collect(1:N) + end end