Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.33.3"
version = "0.34.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
8 changes: 5 additions & 3 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ end
# Extended in contrib/inference/abstractmcmc.jl
getstats(t) = nothing

struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}}
abstract type AbstractTransition end

struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} <: AbstractTransition
θ :: T
lp :: F # TODO: merge `lp` with `stat`
stat :: S
Expand Down Expand Up @@ -409,7 +411,7 @@ getlogevidence(transitions, sampler, state) = missing
# Default MCMCChains.Chains constructor.
# This is type piracy (at least for SampleFromPrior).
function AbstractMCMC.bundle_samples(
ts::Vector,
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
state,
Expand Down Expand Up @@ -472,7 +474,7 @@ end

# This is type piracy (for SampleFromPrior).
function AbstractMCMC.bundle_samples(
ts::Vector,
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
state,
Expand Down
4 changes: 2 additions & 2 deletions src/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end
SMC(space::Symbol...) = SMC(space)
SMC(space::Tuple) = SMC(AdvancedPS.ResampleWithESSThreshold(), space)

struct SMCTransition{T,F<:AbstractFloat}
struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
Expand Down Expand Up @@ -222,7 +222,7 @@ end

const CSMC = PG # type alias of PG as Conditional SMC

struct PGTransition{T,F<:AbstractFloat}
struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
Expand Down
2 changes: 1 addition & 1 deletion src/mcmc/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ function SGLD(
return SGLD{typeof(adtype),space,typeof(stepsize)}(stepsize, adtype)
end

struct SGLDTransition{T,F<:Real}
struct SGLDTransition{T,F<:Real} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample."
Expand Down
8 changes: 4 additions & 4 deletions src/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ StatsBase.loglikelihood(m::ModeResult) = m.lp

"""
Base.get(m::ModeResult, var_symbol::Symbol)
Base.get(m::ModeResult, var_symbols)
Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})

Return the values of all the variables with the symbol(s) `var_symbol` in the mode result
`m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second
argument should be either a `Symbol` or an iterator of `Symbol`s.
argument should be either a `Symbol` or a vector of `Symbol`s.
"""
function Base.get(m::ModeResult, var_symbols)
function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
log_density = m.f
# Get all the variable names in the model. This is the same as the list of keys in
# m.values, but they are more convenient to filter when they are VarNames rather than
Expand All @@ -304,7 +304,7 @@ function Base.get(m::ModeResult, var_symbols)
return (; zip(var_symbols, value_vectors)...)
end

Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,))
Base.get(m::ModeResult, var_symbol::Symbol) = get(m, [var_symbol])

"""
ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
Expand Down
5 changes: 3 additions & 2 deletions test/Aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ module AquaTests
using Aqua: Aqua
using Turing

# TODO(mhauru) We skip testing for method ambiguities because it catches a lot of problems
# in dependencies. Would like to check it for just Turing.jl itself though.
# We test ambiguities separately because it catches a lot of problems
# in dependencies but we test it for Turing.
Aqua.test_ambiguities([Turing])
Aqua.test_all(Turing; ambiguities=false)

end