From 7b4a9724108614557d145b05c1c09a1d05d8cf7d Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 3 Jan 2022 21:21:40 +0000 Subject: [PATCH 1/7] Enable new model evaluation API. --- src/container.jl | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/container.jl b/src/container.jl index d2c59ef9..403a42d3 100644 --- a/src/container.jl +++ b/src/container.jl @@ -5,12 +5,20 @@ end const Particle = Trace -function Trace(f) - ctask = Libtask.CTask(f) +function AdvancedPS.Trace(f) + if f isa Function + ctask = Libtask.CTask(f) + else + # println(f.evaluator) + # CTask(f, args) seems buggy, so we still use CTask(f::TracedModel) for now. + # TODO: fix this after Libtask bug is resolved. + ctask = Libtask.CTask(f.evaluator[1], f.evaluator[2:end]...) + # ctask = Libtask.CTask(f) + end # add backward reference - newtrace = Trace(f, ctask) - addreference!(ctask.task, newtrace) + newtrace = AdvancedPS.Trace(f, ctask) + AdvancedPS.addreference!(ctask.task, newtrace) return newtrace end @@ -42,7 +50,8 @@ end # Create new task and copy randomness function forkr(trace::Trace) newf = reset_model(trace.f) - ctask = Libtask.CTask(trace.ctask) + # ctask = Libtask.CTask(trace.ctask) + ctask = Libtask.CTask(newf.evaluator[1], newf.evaluator[2:end]...) # add backward reference newtrace = Trace(newf, ctask) From f2c569487da2393cac268428e95ddfd947f50084 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 3 Jan 2022 21:38:34 +0000 Subject: [PATCH 2/7] Apply suggestions from code review --- src/container.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/container.jl b/src/container.jl index 403a42d3..ec6f8f83 100644 --- a/src/container.jl +++ b/src/container.jl @@ -5,7 +5,7 @@ end const Particle = Trace -function AdvancedPS.Trace(f) +function Trace(f) if f isa Function ctask = Libtask.CTask(f) else @@ -17,8 +17,8 @@ function AdvancedPS.Trace(f) end # add backward reference - newtrace = AdvancedPS.Trace(f, ctask) - AdvancedPS.addreference!(ctask.task, newtrace) + newtrace = Trace(f, ctask) + addreference!(ctask.task, newtrace) return newtrace end From 6fb4243de25c064c37611ec889144e94afc172e4 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 3 Jan 2022 21:39:44 +0000 Subject: [PATCH 3/7] Apply suggestions from code review --- src/container.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/container.jl b/src/container.jl index ec6f8f83..acca23b1 100644 --- a/src/container.jl +++ b/src/container.jl @@ -10,8 +10,6 @@ function Trace(f) ctask = Libtask.CTask(f) else # println(f.evaluator) - # CTask(f, args) seems buggy, so we still use CTask(f::TracedModel) for now. - # TODO: fix this after Libtask bug is resolved. ctask = Libtask.CTask(f.evaluator[1], f.evaluator[2:end]...) # ctask = Libtask.CTask(f) end From a8f94adc7b9ebbd377a1963c96c99c234ce53781 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 4 Jan 2022 09:10:16 +0000 Subject: [PATCH 4/7] Update container.jl --- src/container.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/container.jl b/src/container.jl index acca23b1..65175bc0 100644 --- a/src/container.jl +++ b/src/container.jl @@ -48,8 +48,12 @@ end # Create new task and copy randomness function forkr(trace::Trace) newf = reset_model(trace.f) - # ctask = Libtask.CTask(trace.ctask) - ctask = Libtask.CTask(newf.evaluator[1], newf.evaluator[2:end]...) + # ctask = Libtask.CTask(trace.ctask) + if newf isa Function + ctask = Libtask.CTask(newf) + else + ctask = Libtask.CTask(newf.evaluator[1], newf.evaluator[2:end]...) + end # add backward reference newtrace = Trace(newf, ctask) From 5cf06d031d0a70dc539be49d7d7e3bc742f2d67c Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 4 Jan 2022 09:10:35 +0000 Subject: [PATCH 5/7] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ba05b2de..40faca94 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,6 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] AbstractMCMC = "2, 3" Distributions = "0.23, 0.24, 0.25" -Libtask = "0.6" +Libtask = "0.6.2" StatsFuns = "0.9" julia = "1.3" From 27493f3ae2427406e438becb64a9340afafb4463 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 4 Jan 2022 09:20:23 +0000 Subject: [PATCH 6/7] Minor bugfix. --- src/container.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/container.jl b/src/container.jl index 65175bc0..8631938e 100644 --- a/src/container.jl +++ b/src/container.jl @@ -6,12 +6,11 @@ end const Particle = Trace function Trace(f) - if f isa Function - ctask = Libtask.CTask(f) - else + if hasfield(typeof(f), :evaluator) # Test whether f is a Turing.TracedModel # println(f.evaluator) ctask = Libtask.CTask(f.evaluator[1], f.evaluator[2:end]...) - # ctask = Libtask.CTask(f) + else # f is a Function, or AdavncedPS.Model + ctask = Libtask.CTask(f) end # add backward reference @@ -49,10 +48,10 @@ end function forkr(trace::Trace) newf = reset_model(trace.f) # ctask = Libtask.CTask(trace.ctask) - if newf isa Function - ctask = Libtask.CTask(newf) - else + if hasfield(typeof(newf), :evaluator) # Test whether f is a Turing.TracedModel ctask = Libtask.CTask(newf.evaluator[1], newf.evaluator[2:end]...) + else # f is a Function, or AdavncedPS.Model + ctask = Libtask.CTask(newf) end # add backward reference From 15e57561fe6f0c43b4d487a92921468853c95031 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 4 Jan 2022 09:55:11 +0000 Subject: [PATCH 7/7] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 40faca94..0b11ca3c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AdvancedPS" uuid = "576499cb-2369-40b2-a588-c64705576edc" authors = ["TuringLang"] -version = "0.3" +version = "0.3.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"