From b19c2ee39ea49db534918e621859bb45b24e0f01 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 15 Dec 2021 00:57:07 +0000 Subject: [PATCH 1/8] fixed bug in replace_returns --- src/compiler.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index c7b310f46..b9c7eead9 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -518,16 +518,17 @@ function replace_returns(e::Expr) end if Meta.isexpr(e, :return) - # NOTE: `return` always has an argument. In the case of - # an empty `return`, the lowered expression will be `return nothing`. - # Hence we don't need any special handling for empty returns. - retval_expr = if length(e.args) > 1 - Expr(:tuple, e.args...) - else - e.args[1] + # We capture the original return-value in `retval` and return + # a `Tuple{typeof(retval),typeof(__varinfo__)}`. + # If we don't capture the return-value separately, cases such as + # `return x = 1` will result in `(x = 1, __varinfo__)` which will + # mistakenly attempt to construct a `NamedTuple` (which fails on Julia 1.3 + # and is not our intent). + @gensym retval + return quote + $retval = $(e.args...) + return $retval, __varinfo__ end - - return :(return ($retval_expr, __varinfo__)) end return Expr(e.head, map(replace_returns, e.args)...) From cf2cf656e50cc201975472590ab690825ecd3a28 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 15 Dec 2021 00:57:34 +0000 Subject: [PATCH 2/8] added test for empty_model --- test/compiler.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/compiler.jl b/test/compiler.jl index 0544bb5f6..a3ef5f6bd 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -546,6 +546,13 @@ end end @testset "return value" begin + # Make sure that a return-value of `x = 1` isn't combined into + # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. + @model empty_model() = begin x = 1; end + empty_vi = VarInfo() + retval_and_vi = DynamicPPL.evaluate!!(empty_model(), empty_vi, SamplingContext()) + @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} + # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ From 2a4cc569e4b7b3932f2670bf158e877a5b35f711 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 15 Dec 2021 01:04:15 +0000 Subject: [PATCH 3/8] bumped patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 844289af5..5257e90f8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.17.0" +version = "0.17.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From bdc3e650b195076b8b23689f8c79f032f5509266 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 15 Dec 2021 01:08:06 +0000 Subject: [PATCH 4/8] styling for test --- test/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/compiler.jl b/test/compiler.jl index a3ef5f6bd..d242d52eb 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -548,7 +548,7 @@ end @testset "return value" begin # Make sure that a return-value of `x = 1` isn't combined into # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. - @model empty_model() = begin x = 1; end + @model empty_model() = return x = 1; empty_vi = VarInfo() retval_and_vi = DynamicPPL.evaluate!!(empty_model(), empty_vi, SamplingContext()) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} From 34f22e8a0dbf723766f719b18037e88c33d015dc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 15 Dec 2021 01:10:31 +0000 Subject: [PATCH 5/8] Update test/compiler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/compiler.jl b/test/compiler.jl index d242d52eb..9dc81ff16 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -548,7 +548,7 @@ end @testset "return value" begin # Make sure that a return-value of `x = 1` isn't combined into # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. - @model empty_model() = return x = 1; + @model empty_model() = return x = 1 empty_vi = VarInfo() retval_and_vi = DynamicPPL.evaluate!!(empty_model(), empty_vi, SamplingContext()) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} From eec5c65d788fa27a244a15007bfd92564f1c1ee4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 15 Dec 2021 20:37:35 +0000 Subject: [PATCH 6/8] added requires_threadsafe to allow samplers to decide whether to use threadsafe varinfo or not --- Project.toml | 2 +- src/model.jl | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 5257e90f8..8ecae82df 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.17.1" +version = "0.17.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/model.jl b/src/model.jl index 702d76a17..1994c192d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -376,6 +376,17 @@ number of `sampler`. """ (model::Model)(args...) = first(evaluate!!(model, args...)) + +""" + requires_threadsafe(context::AbstractContext, varinfo::AbstractVarInfo) + +Return `true` if evaluation of a model using `context` and `varinfo` requires +and supports wrapping `varinfo` in `ThreadSafeVarInfo`, and `false` otherwise. +""" +function requires_threadsafe(context::AbstractContext, varinfo::AbstractVarInfo) + return Threads.nthreads() > 1 +end + """ evaluate!!(model::Model[, rng, varinfo, sampler, context]) @@ -388,10 +399,10 @@ The method resets the log joint probability of `varinfo` and increases the evalu number of `sampler`. """ function evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) - if Threads.nthreads() == 1 - return evaluate_threadunsafe!!(model, varinfo, context) + return if requires_threadsafe(context, varinfo) + evaluate_threadsafe!!(model, varinfo, context) else - return evaluate_threadsafe!!(model, varinfo, context) + evaluate_threadunsafe!!(model, varinfo, context) end end From e2f4fb5bba9eee85db0e65948b3ac5a2243ecf96 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 15 Dec 2021 20:48:11 +0000 Subject: [PATCH 7/8] Update src/model.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/model.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 1994c192d..8ad2f9f2e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -376,7 +376,6 @@ number of `sampler`. """ (model::Model)(args...) = first(evaluate!!(model, args...)) - """ requires_threadsafe(context::AbstractContext, varinfo::AbstractVarInfo) From 26aa02f2bf9f4b003584181feca30fb725927761 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 15 Dec 2021 21:40:19 +0000 Subject: [PATCH 8/8] rename requires_threadsafe to use_threadsafe_eval --- src/model.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/model.jl b/src/model.jl index 1994c192d..16af76b5b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -378,12 +378,12 @@ number of `sampler`. """ - requires_threadsafe(context::AbstractContext, varinfo::AbstractVarInfo) + use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) -Return `true` if evaluation of a model using `context` and `varinfo` requires -and supports wrapping `varinfo` in `ThreadSafeVarInfo`, and `false` otherwise. +Return `true` if evaluation of a model using `context` and `varinfo` should +wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. """ -function requires_threadsafe(context::AbstractContext, varinfo::AbstractVarInfo) +function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) return Threads.nthreads() > 1 end @@ -399,7 +399,7 @@ The method resets the log joint probability of `varinfo` and increases the evalu number of `sampler`. """ function evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) - return if requires_threadsafe(context, varinfo) + return if use_threadsafe_eval(context, varinfo) evaluate_threadsafe!!(model, varinfo, context) else evaluate_threadunsafe!!(model, varinfo, context)