diff --git a/Project.toml b/Project.toml index 5257e90f8..8ecae82df 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.17.1" +version = "0.17.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/model.jl b/src/model.jl index 702d76a17..9890ab0d4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -376,6 +376,16 @@ number of `sampler`. """ (model::Model)(args...) = first(evaluate!!(model, args...)) +""" + use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) + +Return `true` if evaluation of a model using `context` and `varinfo` should +wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. +""" +function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) + return Threads.nthreads() > 1 +end + """ evaluate!!(model::Model[, rng, varinfo, sampler, context]) @@ -388,10 +398,10 @@ The method resets the log joint probability of `varinfo` and increases the evalu number of `sampler`. """ function evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) - if Threads.nthreads() == 1 - return evaluate_threadunsafe!!(model, varinfo, context) + return if use_threadsafe_eval(context, varinfo) + evaluate_threadsafe!!(model, varinfo, context) else - return evaluate_threadsafe!!(model, varinfo, context) + evaluate_threadunsafe!!(model, varinfo, context) end end