@@ -376,6 +376,16 @@ number of `sampler`.
376376"""
377377(model:: Model )(args... ) = first (evaluate!! (model, args... ))
378378
379+ """
380+ use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
381+
382+ Return `true` if evaluation of a model using `context` and `varinfo` should
383+ wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise.
384+ """
385+ function use_threadsafe_eval (context:: AbstractContext , varinfo:: AbstractVarInfo )
386+ return Threads. nthreads () > 1
387+ end
388+
379389"""
380390 evaluate!!(model::Model[, rng, varinfo, sampler, context])
381391
@@ -388,10 +398,10 @@ The method resets the log joint probability of `varinfo` and increases the evalu
388398number of `sampler`.
389399"""
390400function evaluate!! (model:: Model , varinfo:: AbstractVarInfo , context:: AbstractContext )
391- if Threads . nthreads () == 1
392- return evaluate_threadunsafe !! (model, varinfo, context)
401+ return if use_threadsafe_eval (context, varinfo)
402+ evaluate_threadsafe !! (model, varinfo, context)
393403 else
394- return evaluate_threadsafe !! (model, varinfo, context)
404+ evaluate_threadunsafe !! (model, varinfo, context)
395405 end
396406end
397407
0 commit comments