Skip to content

Commit 2f93d75

Browse files
committed
Merge remote-tracking branch 'origin/master' into mhauru/tapir-tests
2 parents fab002d + a26ce11 commit 2f93d75

File tree

10 files changed

+315
-71
lines changed

10 files changed

+315
-71
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.33.3"
3+
version = "0.34.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/Inference.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,9 @@ end
213213
# Extended in contrib/inference/abstractmcmc.jl
214214
getstats(t) = nothing
215215

216-
struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}}
216+
abstract type AbstractTransition end
217+
218+
struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} <: AbstractTransition
217219
θ :: T
218220
lp :: F # TODO: merge `lp` with `stat`
219221
stat :: S
@@ -409,7 +411,7 @@ getlogevidence(transitions, sampler, state) = missing
409411
# Default MCMCChains.Chains constructor.
410412
# This is type piracy (at least for SampleFromPrior).
411413
function AbstractMCMC.bundle_samples(
412-
ts::Vector,
414+
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
413415
model::AbstractModel,
414416
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
415417
state,
@@ -472,7 +474,7 @@ end
472474

473475
# This is type piracy (for SampleFromPrior).
474476
function AbstractMCMC.bundle_samples(
475-
ts::Vector,
477+
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
476478
model::AbstractModel,
477479
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
478480
state,

src/mcmc/hmc.jl

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -563,58 +563,3 @@ end
563563
function AHMCAdaptor(::Hamiltonian, ::AHMC.AbstractMetric; kwargs...)
564564
return AHMC.Adaptation.NoAdaptation()
565565
end
566-
567-
##########################
568-
# HMC State Constructors #
569-
##########################
570-
571-
function HMCState(
572-
rng::AbstractRNG,
573-
model::Model,
574-
spl::Sampler{<:Hamiltonian},
575-
vi::AbstractVarInfo;
576-
kwargs...,
577-
)
578-
# Link everything if needed.
579-
waslinked = islinked(vi, spl)
580-
if !waslinked
581-
vi = link!!(vi, spl, model)
582-
end
583-
584-
# Get the initial log pdf and gradient functions.
585-
∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model)
586-
logπ = Turing.LogDensityFunction(
587-
vi,
588-
model,
589-
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)),
590-
)
591-
592-
# Get the metric type.
593-
metricT = getmetricT(spl.alg)
594-
595-
# Create a Hamiltonian.
596-
θ_init = Vector{Float64}(spl.state.vi[spl])
597-
metric = metricT(length(θ_init))
598-
h = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)
599-
600-
# Find good eps if not provided one
601-
if iszero(spl.alg.ϵ)
602-
ϵ = AHMC.find_good_stepsize(rng, h, θ_init)
603-
@info "Found initial step size" ϵ
604-
else
605-
ϵ = spl.alg.ϵ
606-
end
607-
608-
# Generate a kernel.
609-
kernel = make_ahmc_kernel(spl.alg, ϵ)
610-
611-
# Generate a phasepoint. Replaced during sample_init!
612-
h, t = AHMC.sample_init(rng, h, θ_init) # this also ensures AHMC has the same dim as θ.
613-
614-
# Unlink everything, if it was indeed linked before.
615-
if waslinked
616-
vi = invlink!!(vi, spl, model)
617-
end
618-
619-
return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
620-
end

src/mcmc/particle_mcmc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545
SMC(space::Symbol...) = SMC(space)
4646
SMC(space::Tuple) = SMC(AdvancedPS.ResampleWithESSThreshold(), space)
4747

48-
struct SMCTransition{T,F<:AbstractFloat}
48+
struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition
4949
"The parameters for any given sample."
5050
θ::T
5151
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
@@ -222,7 +222,7 @@ end
222222

223223
const CSMC = PG # type alias of PG as Conditional SMC
224224

225-
struct PGTransition{T,F<:AbstractFloat}
225+
struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition
226226
"The parameters for any given sample."
227227
θ::T
228228
"The joint log probability of the sample (NOTE: does not work, always set to zero)."

src/mcmc/sghmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ function SGLD(
193193
return SGLD{typeof(adtype),space,typeof(stepsize)}(stepsize, adtype)
194194
end
195195

196-
struct SGLDTransition{T,F<:Real}
196+
struct SGLDTransition{T,F<:Real} <: AbstractTransition
197197
"The parameters for any given sample."
198198
θ::T
199199
"The joint log probability of the sample."

src/optimisation/Optimisation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,13 @@ StatsBase.loglikelihood(m::ModeResult) = m.lp
277277

278278
"""
279279
Base.get(m::ModeResult, var_symbol::Symbol)
280-
Base.get(m::ModeResult, var_symbols)
280+
Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
281281
282282
Return the values of all the variables with the symbol(s) `var_symbol` in the mode result
283283
`m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second
284-
argument should be either a `Symbol` or an iterator of `Symbol`s.
284+
argument should be either a `Symbol` or a vector of `Symbol`s.
285285
"""
286-
function Base.get(m::ModeResult, var_symbols)
286+
function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
287287
log_density = m.f
288288
# Get all the variable names in the model. This is the same as the list of keys in
289289
# m.values, but they are more convenient to filter when they are VarNames rather than
@@ -304,7 +304,7 @@ function Base.get(m::ModeResult, var_symbols)
304304
return (; zip(var_symbols, value_vectors)...)
305305
end
306306

307-
Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,))
307+
Base.get(m::ModeResult, var_symbol::Symbol) = get(m, [var_symbol])
308308

309309
"""
310310
ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)

test/Aqua.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ module AquaTests
33
using Aqua: Aqua
44
using Turing
55

6-
# TODO(mhauru) We skip testing for method ambiguities because it catches a lot of problems
7-
# in dependencies. Would like to check it for just Turing.jl itself though.
6+
# We test ambiguities separately because running it on all dependencies catches
7+
# many unrelated problems; here we still want to test it for Turing itself.
8+
Aqua.test_ambiguities([Turing])
89
Aqua.test_all(Turing; ambiguities=false)
910

1011
end

test/mcmc/hmc.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module HMCTests
22

33
using ..Models: gdemo_default
4+
using ..ADUtils: ADTypeCheckContext
45
#using ..Models: gdemo
56
using ..NumericalTests: check_gdemo, check_numerical
67
import ..ADUtils
@@ -324,6 +325,15 @@ ADUtils.install_tapir && import Tapir
324325
# KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
325326
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001
326327
end
328+
329+
@testset "Check ADType" begin
330+
alg = HMC(0.1, 10; adtype=adbackend)
331+
m = DynamicPPL.contextualize(
332+
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
333+
)
334+
# These will error if the adbackend being used is not the one set.
335+
sample(rng, m, alg, 10)
336+
end
327337
end
328338

329339
end

test/optimisation/Optimisation.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module OptimisationTests
22

33
using ..Models: gdemo, gdemo_default
4+
using ..ADUtils: ADTypeCheckContext
45
using Distributions
56
using Distributions.FillArrays: Zeros
67
using DynamicPPL: DynamicPPL
@@ -140,7 +141,6 @@ using Turing
140141
gdemo_default, OptimizationOptimJL.LBFGS(); initial_params=true_value
141142
)
142143
m3 = maximum_likelihood(gdemo_default, OptimizationOptimJL.Newton())
143-
# TODO(mhauru) How can we check that the adtype is actually AutoReverseDiff?
144144
m4 = maximum_likelihood(
145145
gdemo_default, OptimizationOptimJL.BFGS(); adtype=AutoReverseDiff()
146146
)
@@ -616,6 +616,18 @@ using Turing
616616
@assert vcat(get_a[:a], get_b[:b]) == result.values.array
617617
@assert get(result, :c) == (; :c => Array{Float64}[])
618618
end
619+
620+
@testset "ADType" begin
621+
Random.seed!(222)
622+
for adbackend in (AutoReverseDiff(), AutoForwardDiff(), AutoTracker())
623+
m = DynamicPPL.contextualize(
624+
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
625+
)
626+
# These will error if the adbackend being used is not the one set.
627+
maximum_likelihood(m; adtype=adbackend)
628+
maximum_a_posteriori(m; adtype=adbackend)
629+
end
630+
end
619631
end
620632

621633
end

0 commit comments

Comments
 (0)