From 352e802b118156282e569c02dbf9642b200b5fad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Thu, 25 Feb 2021 17:47:54 +0100 Subject: [PATCH 1/4] Include `state` and `kwargs...` to `callback` --- src/sample.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 9a03f444..6a4966e9 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -87,7 +87,7 @@ function mcmcsample( end # Run callback. - callback === nothing || callback(rng, model, sampler, sample, 1) + callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) # Save the sample. samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) @@ -115,7 +115,7 @@ function mcmcsample( sample, state = step(rng, model, sampler, state; kwargs...) # Run callback. - callback === nothing || callback(rng, model, sampler, sample, i) + callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...) # Save the sample. samples = save!!(samples, sample, i, model, sampler, N; kwargs...) From 4365052407b73d73291157312be04385e42ed954 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Thu, 25 Feb 2021 17:49:34 +0100 Subject: [PATCH 2/4] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 609310b8..c88f08ab 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ are: - `progress` (default: `true`): toggles progress logging - `chain_type` (default: `Any`): determines the type of the returned chain - `callback` (default: `nothing`): if `callback !== nothing`, then - `callback(rng, model, sampler, sample, iteration)` is called after every sampling step, + `callback(rng, model, sampler, sample, state, iteration; kwargs...)` is called after every sampling step, where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration - `discard_initial` (default: `0`): number of initial samples that are discarded - `thinning` (default: `1`): factor by which to thin samples. From e58fe090fcdd58a59e38ed717b060098f333d001 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Thu, 25 Feb 2021 17:55:22 +0100 Subject: [PATCH 3/4] Add test for testing callbacks do the right thing. --- test/sample.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/sample.jl b/test/sample.jl index f50f7881..af9df0f4 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -276,4 +276,14 @@ @test mean(x.b for x in chain) ≈ 0 atol=0.1 @test var(x.b for x in chain) ≈ 1 atol=0.15 end + + @testset "Testing callbacks" begin + function count_iterations(rng, model, sampler, sample, state, i; iter_array, kwargs...) + iter_array[i] = i + end + N = 100 + it_array = zeros(N) + sample(MyModel(), MySampler(), N; callback=count_iterations, iter_array=it_array) + @test it_array == collect(1:N) + end end From 40a339b6def9551dfb8cc344858c424576602374 Mon Sep 17 00:00:00 2001 From: Cameron Pfiffer Date: Tue, 6 Apr 2021 17:15:42 -0700 Subject: [PATCH 4/4] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8e7e90dc..477ccc9d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "2.2.1" +version = "3.0.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"