diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 8c598a6a8..6a8f0ccb1 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -54,7 +54,7 @@ function DynamicPPL.generated_quantities( # TODO: Some of the variables can be a view into the `varinfo`, so we need to # `deepcopy` the `varinfo` before passing it to `model`. - model(deepcopy(varinfo)) + model(deepcopy(varinfo), DynamicPPL.PostProcessingContext()) end end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 0ccfbb103..8ad05b068 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -129,6 +129,7 @@ export AbstractVarInfo, unfix, # Convenience macros @addlogprob!, + @is_post_processing, @submodel, value_iterator_from_chain diff --git a/src/compiler.jl b/src/compiler.jl index f8a04a557..c6eecd86b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,5 +1,27 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) +""" + check_if_in_model_block_expr(name) + +Return an expression that can be evaluated to check if we're inside a model block. + +# Arguments +- `name`: The name of the variable or method that can only be used inside a model block. + Error message will include this name. +""" +function check_if_in_model_block_expr(name) + return Expr( + :||, + Expr( + :&&, + Expr(:isdefined, esc(:__model__)), + Expr(:call, :isa, esc(:__model__), Model), + ), + # Otherwise, throw error. + :(error($(string(name)) * " can only be used inside a model block")), + ) +end + """ need_concretize(expr) diff --git a/src/contexts.jl b/src/contexts.jl index 2018b9155..f5bed3dda 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -664,3 +664,66 @@ function fixed(context::FixedContext) # precedence over decendants of `context`. return merge(context.values, fixed(childcontext(context))) end + +"""" + PostProcessingContext + +Simple context used to indicate that the model is being evaluated with the aim +of post-processing the inference results, e.g. making predictions or computing +generated quantities. +""" +struct PostProcessingContext{Ctx} <: AbstractContext + context::AbstractContext +end + +function PostProcessingContext(context::AbstractContext) + return PostProcessingContext{typeof(context)}(context) +end +PostProcessingContext() = PostProcessingContext(DefaultContext()) + +NodeTrait(::PostProcessingContext) = IsParent() +childcontext(context::PostProcessingContext) = context.context +function setchildcontext(context::PostProcessingContext, child) + return PostProcessingContext(child) +end + +function is_post_processing(context::AbstractContext) + return is_post_processing(NodeTrait(is_post_processing, context), context) +end +is_post_processing(::IsLeaf, context) = false +is_post_processing(::IsParent, context) = is_post_processing(childcontext(context)) +is_post_processing(context::PostProcessingContext) = true + +""" + @is_post_processing + +Return `true` if the model is being evaluated with the aim of post-processing +inference results, e.g. making predictions or computing generated quantities. + +# Examples + +```jldoctest; setup = :(using Distributions) +julia> @model function demo() + x ~ Normal(0, 1) + return if @is_post_processing + x + else + nothing + end + end +demo (generic function with 2 methods) + +julia> model = demo(); + +julia> model() # (✓) Returns nothing + +julia> generated_quantities(model, (x = 1,)) # (✓) Returns 1.0 +1.0 +``` +""" +macro is_post_processing() + return quote + $(check_if_in_model_block_expr("@is_post_processing")) + $(is_post_processing)($(esc(:__context__))) + end +end diff --git a/src/model.jl b/src/model.jl index 8c10ed36e..2825081ee 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1253,7 +1253,7 @@ function generated_quantities(model::Model, chain::AbstractChains) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) setval_and_resample!(varinfo, chain, sample_idx, chain_idx) - model(varinfo) + model(varinfo, PostProcessingContext()) end end @@ -1295,11 +1295,11 @@ julia> generated_quantities(model, values(parameters), keys(parameters)) function generated_quantities(model::Model, parameters::NamedTuple) varinfo = VarInfo(model) setval_and_resample!(varinfo, values(parameters), keys(parameters)) - return model(varinfo) + return model(varinfo, PostProcessingContext()) end function generated_quantities(model::Model, values, keys) varinfo = VarInfo(model) setval_and_resample!(varinfo, values, keys) - return model(varinfo) + return model(varinfo, PostProcessingContext()) end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 050bf31fc..479a52214 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -241,6 +241,7 @@ function submodel(prefix_expr, expr, ctx=esc(:__context__)) ) end quote + $(check_if_in_model_block_expr("@submodel")) $retval, $(esc(:__varinfo__)) = $(_evaluate!!)( $(esc(R)), $(esc(:__varinfo__)), $(ctx) ) diff --git a/src/utils.jl b/src/utils.jl index 15c7078b8..2e4cbbc1c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -70,6 +70,7 @@ true """ macro addlogprob!(ex) return quote + $(check_if_in_model_block_expr("@addlogprob!")) $(esc(:(__varinfo__))) = acclogp!!( $(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex)) )