diff --git a/src/compiler.jl b/src/compiler.jl index f973a3bb6..c7b310f46 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,7 +1,7 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) """ - isassumption(expr) + isassumption(expr[, vn]) Return an expression that can be evaluated to check if `expr` is an assumption in the model. @@ -13,39 +13,44 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases: but `x[1] === missing`. When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. -""" -function isassumption(expr::Union{Symbol,Expr}) - vn = gensym(:vn) +If `vn` is specified, it will be assumed to refer to a expression which +evaluates to a `VarName`, and this will be used in the subsequent checks. +If `vn` is not specified, `AbstractPPL.drop_escape(varname(expr))` will be +used in its place. +""" +function isassumption(expr::Union{Expr,Symbol}, vn=AbstractPPL.drop_escape(varname(expr))) return quote - let $vn = $(AbstractPPL.drop_escape(varname(expr))) - if $(DynamicPPL.contextual_isassumption)(__context__, $vn) - # Considered an assumption by `__context__` which means either: - # 1. We hit the default implementation, e.g. using `DefaultContext`, - # which in turn means that we haven't considered if it's one of - # the model arguments, hence we need to check this. - # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, - # i.e. we're trying to condition one of the latent variables. - # In this case, the below will return `true` since the first branch - # will be hit. - # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, - # i.e. we're trying to override the value. This is currently NOT supported. - # TODO: Support by adding context to model, and use `model.args` - # as the default conditioning. Then we no longer need to check `inargnames` - # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) || - $(DynamicPPL.inmissings)($vn, __model__) - true - else - $(maybe_view(expr)) === missing - end + if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + # Considered an assumption by `__context__` which means either: + # 1. We hit the default implementation, e.g. using `DefaultContext`, + # which in turn means that we haven't considered if it's one of + # the model arguments, hence we need to check this. + # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, + # i.e. we're trying to condition one of the latent variables. + # In this case, the below will return `true` since the first branch + # will be hit. + # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, + # i.e. we're trying to override the value. This is currently NOT supported. + # TODO: Support by adding context to model, and use `model.args` + # as the default conditioning. Then we no longer need to check `inargnames` + # since it will all be handled by `contextual_isassumption`. + if !($(DynamicPPL.inargnames)($vn, __model__)) || + $(DynamicPPL.inmissings)($vn, __model__) + true else - false + $(maybe_view(expr)) === missing end + else + false end end end +# failsafe: a literal is never an assumption +isassumption(expr, vn) = :(false) +isassumption(expr) = :(false) + """ contextual_isassumption(context, vn) @@ -79,9 +84,6 @@ function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(childcontext(context), prefix(context, vn)) end -# failsafe: a literal is never an assumption -isassumption(expr) = :(false) - # If we're working with, say, a `Symbol`, then we're not going to `view`. maybe_view(x) = x maybe_view(x::Expr) = :(@views($x)) @@ -382,7 +384,7 @@ function generate_tilde(left, right) # more selective with our escape. Until that's the case, we remove them all. return quote $vn = $(AbstractPPL.drop_escape(varname(left))) - $isassumption = $(DynamicPPL.isassumption(left)) + $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption $(generate_tilde_assume(left, right, vn)) else @@ -439,7 +441,7 @@ function generate_dot_tilde(left, right) @gensym vn isassumption value return quote $vn = $(AbstractPPL.drop_escape(varname(left))) - $isassumption = $(DynamicPPL.isassumption(left)) + $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption $(generate_dot_tilde_assume(left, right, vn)) else