diff --git a/Project.toml b/Project.toml index ba05b2de..0b11ca3c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AdvancedPS" uuid = "576499cb-2369-40b2-a588-c64705576edc" authors = ["TuringLang"] -version = "0.3" +version = "0.3.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -13,6 +13,6 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] AbstractMCMC = "2, 3" Distributions = "0.23, 0.24, 0.25" -Libtask = "0.6" +Libtask = "0.6.2" StatsFuns = "0.9" julia = "1.3" diff --git a/src/container.jl b/src/container.jl index d2c59ef9..8631938e 100644 --- a/src/container.jl +++ b/src/container.jl @@ -6,7 +6,12 @@ end const Particle = Trace function Trace(f) - ctask = Libtask.CTask(f) + if hasfield(typeof(f), :evaluator) # Test whether f is a Turing.TracedModel + # println(f.evaluator) + ctask = Libtask.CTask(f.evaluator[1], f.evaluator[2:end]...) + else # f is a Function, or AdavncedPS.Model + ctask = Libtask.CTask(f) + end # add backward reference newtrace = Trace(f, ctask) @@ -42,7 +47,12 @@ end # Create new task and copy randomness function forkr(trace::Trace) newf = reset_model(trace.f) - ctask = Libtask.CTask(trace.ctask) + # ctask = Libtask.CTask(trace.ctask) + if hasfield(typeof(newf), :evaluator) # Test whether f is a Turing.TracedModel + ctask = Libtask.CTask(newf.evaluator[1], newf.evaluator[2:end]...) + else # f is a Function, or AdavncedPS.Model + ctask = Libtask.CTask(newf) + end # add backward reference newtrace = Trace(newf, ctask)