Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Release 0.39.0

## Removal of Turing.Essential

The Turing.Essential module has been removed.
Anything exported from there can be imported from either `Turing` or `DynamicPPL`.

## `@addlogprob!`

The `@addlogprob!` macro is now exported from Turing, making it officially part of the public interface.

# Release 0.38.4

The minimum Julia version was increased to 1.10.2 (from 1.10.0).
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
| `to_submodel` | [`DynamicPPL.to_submodel`](@extref) | Define a submodel |
| `prefix` | [`DynamicPPL.prefix`](@extref) | Prefix all variable names in a model with a given VarName |
| `LogDensityFunction` | [`DynamicPPL.LogDensityFunction`](@extref) | A struct containing all information about how to evaluate a model. Mostly for advanced users |
| `@addlogprob!` | [`DynamicPPL.@addlogprob!`](@extref) | Add arbitrary log-probability terms during model evaluation |

### Inference

Expand Down
14 changes: 8 additions & 6 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using Printf: Printf
using Random: Random
using LinearAlgebra: I

using ADTypes: ADTypes
using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake

const DEFAULT_ADTYPE = ADTypes.AutoForwardDiff()

Expand All @@ -47,8 +47,6 @@ end
# Random probability measures.
include("stdlib/distributions.jl")
include("stdlib/RandomMeasures.jl")
include("essential/Essential.jl")
using .Essential
include("mcmc/Inference.jl") # inference algorithms
using .Inference
include("variational/VariationalInference.jl")
Expand All @@ -57,13 +55,13 @@ using .Variational
include("optimisation/Optimisation.jl")
using .Optimisation

include("deprecated.jl") # to be removed in the next minor version release

###########
# Exports #
###########
# `using` statements for stuff to re-export
using DynamicPPL:
@model,
@varname,
pointwise_loglikelihoods,
generated_quantities,
returned,
Expand All @@ -73,9 +71,12 @@ using DynamicPPL:
decondition,
fix,
unfix,
prefix,
conditioned,
@submodel,
to_submodel,
LogDensityFunction
LogDensityFunction,
@addlogprob!
using StatsBase: predict
using OrderedCollections: OrderedDict

Expand All @@ -90,6 +91,7 @@ export
to_submodel,
prefix,
LogDensityFunction,
@addlogprob!,
# Sampling - AbstractMCMC
sample,
MCMCThreads,
Expand Down
39 changes: 0 additions & 39 deletions src/deprecated.jl

This file was deleted.

24 changes: 0 additions & 24 deletions src/essential/Essential.jl

This file was deleted.

70 changes: 0 additions & 70 deletions src/essential/container.jl

This file was deleted.

2 changes: 1 addition & 1 deletion src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Inference

using ..Essential
using DynamicPPL:
@model,
Metadata,
VarInfo,
# TODO(mhauru) all_varnames_grouped_by_symbol isn't exported by DPPL, because it is only
Expand Down
75 changes: 74 additions & 1 deletion src/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,79 @@
### Particle Filtering and Particle MCMC Samplers.
###

### AdvancedPS models and interface

struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <:
AdvancedPS.AbstractGenericModel
model::M
sampler::S
varinfo::V
evaluator::E
end

function TracedModel(
model::Model,
sampler::AbstractSampler,
varinfo::AbstractVarInfo,
rng::Random.AbstractRNG,
)
context = SamplingContext(rng, sampler, DefaultContext())
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
if kwargs !== nothing && !isempty(kwargs)
error(
"Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.",
)
end
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
model, sampler, varinfo, (model.f, args...)
)
end

function AdvancedPS.advance!(
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false
)
# Make sure we load/reset the rng in the new replaying mechanism
DynamicPPL.increment_num_produce!(trace.model.f.varinfo)
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
score = consume(trace.model.ctask)
if score === nothing
return nothing
else
return score + DynamicPPL.getlogp(trace.model.f.varinfo)
end
end

function AdvancedPS.delete_retained!(trace::TracedModel)
DynamicPPL.set_retained_vns_del!(trace.varinfo)
return trace
end

function AdvancedPS.reset_model(trace::TracedModel)
DynamicPPL.reset_num_produce!(trace.varinfo)
return trace
end

function AdvancedPS.reset_logprob!(trace::TracedModel)
DynamicPPL.resetlogp!!(trace.model.varinfo)
return trace

Check warning on line 59 in src/mcmc/particle_mcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/particle_mcmc.jl#L57-L59

Added lines #L57 - L59 were not covered by tests
end

function AdvancedPS.update_rng!(
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}
)
# Extract the `args`.
args = trace.model.ctask.args
# From `args`, extract the `SamplingContext`, which contains the RNG.
sampling_context = args[3]
rng = sampling_context.rng
trace.rng = rng
return trace
end

function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ?
return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...)
end

####
#### Generic Sequential Monte Carlo sampler.
####
Expand Down Expand Up @@ -408,7 +481,7 @@
newvarinfo = deepcopy(varinfo)
DynamicPPL.reset_num_produce!(newvarinfo)

tmodel = Turing.Essential.TracedModel(model, sampler, newvarinfo, rng)
tmodel = TracedModel(model, sampler, newvarinfo, rng)
newtrace = AdvancedPS.Trace(tmodel, rng)
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
return newtrace
Expand Down
20 changes: 10 additions & 10 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Ration

"""A dictionary mapping ADTypes to the element types they use."""
eltypes_by_adtype = Dict(
Turing.AutoForwardDiff => (ForwardDiff.Dual,),
Turing.AutoReverseDiff => (
AutoForwardDiff => (ForwardDiff.Dual,),
AutoReverseDiff => (
ReverseDiff.TrackedArray,
ReverseDiff.TrackedMatrix,
ReverseDiff.TrackedReal,
Expand All @@ -37,7 +37,7 @@ eltypes_by_adtype = Dict(
),
)
if INCLUDE_MOONCAKE
eltypes_by_adtype[Turing.AutoMooncake] = (Mooncake.CoDual,)
eltypes_by_adtype[AutoMooncake] = (Mooncake.CoDual,)
end

"""
Expand Down Expand Up @@ -189,32 +189,32 @@ end
"""
All the ADTypes on which we want to run the tests.
"""
ADTYPES = [Turing.AutoForwardDiff(), Turing.AutoReverseDiff(; compile=false)]
ADTYPES = [AutoForwardDiff(), AutoReverseDiff(; compile=false)]
if INCLUDE_MOONCAKE
push!(ADTYPES, Turing.AutoMooncake(; config=nothing))
push!(ADTYPES, AutoMooncake(; config=nothing))
end

# Check that ADTypeCheckContext itself works as expected.
@testset "ADTypeCheckContext" begin
@model test_model() = x ~ Normal(0, 1)
tm = test_model()
adtypes = (
Turing.AutoForwardDiff(),
Turing.AutoReverseDiff(),
AutoForwardDiff(),
AutoReverseDiff(),
# Don't need to test Mooncake as it doesn't use tracer types
)
for actual_adtype in adtypes
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
sampler = HMC(0.1, 5; adtype=actual_adtype)
for expected_adtype in adtypes
contextualised_tm = DynamicPPL.contextualize(
tm, ADTypeCheckContext(expected_adtype, tm.context)
)
@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
if actual_adtype == expected_adtype
# Check that this does not throw an error.
Turing.sample(contextualised_tm, sampler, 2)
sample(contextualised_tm, sampler, 2)
else
@test_throws AbstractWrongADBackendError Turing.sample(
@test_throws AbstractWrongADBackendError sample(
contextualised_tm, sampler, 2
)
end
Expand Down
Loading