diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..b73e1eff6 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -101,6 +101,7 @@ export AbstractVarInfo, pointwise_loglikelihoods, # Convenience macros @addlogprob!, + @isassumption, @submodel # Reexport diff --git a/src/compiler.jl b/src/compiler.jl index 91fe78e2b..1fb985222 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2,39 +2,107 @@ const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__ const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng) """ - isassumption(expr) + @isassumption x + @isassumption model x[, varname] -Return an expression that can be evaluated to check if `expr` is an assumption in the -model. +Return `true` if `x` is an assumption and `false` otherwise. -Let `expr` be `:(x[1])`. It is an assumption in the following cases: - 1. `x` is not among the input data to the model, - 2. `x` is among the input data to the model but with a value `missing`, or - 3. `x` is among the input data to the model with a value other than missing, - but `x[1] === missing`. +E.g. `x[1]` is an assumption in the following cases: + 1. `x` is not among the input data to the model, or + 2. `x` is among the input data to the model but with `value === missing`.! -When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. +A literal, e.g. `1.0`, results in `false`. + +# Examples +```jldoctest +julia> @model demo(x) = x ~ Normal(); # univariate + +julia> @isassumption(demo(1.0), x) +false + +julia> @isassumption(demo(1.0), y) +true + +julia> x = missing; @isassumption(demo(1.0), x) +true + +julia> @model demov(x) = x .~ Normal(); # multivariate + +julia> x = [1.0, 1.0]; + +julia> @isassumption(demov(x), x) +false + +julia> @isassumption(demov(x), y) +true + +julia> @isassumption(demov(x), missing) +true + +julia> x = [1.0, missing]; # partially missing not supported for multivariate + +julia> @isassumption(demov(x), x) +ERROR: x have some `missing` and some not; this is currently not supported + +julia> @isassumption(demov(x), y) +true + +julia> x = [missing, missing]; # fully missing supported for multivariate + +julia> @isassumption(demov(x), x) +true +``` + +See also: [`isassumption`](@ref) """ -function isassumption(expr::Union{Symbol,Expr}) - vn = gensym(:vn) +macro isassumption(left) + return esc(isassumption(:(__model__), left)) +end +macro isassumption(model, left, vn=varname(left)) + return esc(isassumption(model, left, vn)) +end - return quote - let $vn = $(varname(expr)) - # This branch should compile nicely in all cases except for partial missing data - # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) - true - else - # Evaluate the LHS - $(maybe_view(expr)) === missing - end - end +""" + isassumption(model, left[, vn]) + +Return an expression evaluating to `true` if expression `left` is considered +as an assumption in `model`, where `model` evaluates to a [`Model`](@ref). + +If `vn` is specified, is is assumed to evaluate to `varname(left)`. +If `vn` is not specified, `varname(left)` is used in instead. + +See also: [`@isassumption`](@ref) +""" +function isassumption(model, left, vn=varname(left)) + if isliteral(left) + return :(false) end + + sym = vsym(left) + return :( + (!$(DynamicPPL.inargnames)($vn, $model) || $(DynamicPPL.inmissings)($vn, $model)) || + ( + @isdefined($sym) && + ($left === $(missing) || $(DynamicPPL.is_entirely_missing)($vn, $left)) + ) + ) end -# failsafe: a literal is never an assumption -isassumption(expr) = :(false) +is_entirely_missing(vn, x) = false +function is_entirely_missing(vn::VarName, x::AbstractArray{>:Missing}) + num_missing = count(x -> x === missing, x) + if num_missing == length(x) + # All are `missing`. + return true + end + + if num_missing > 0 + # Only some are `missing` => we don't know what to do. + error("$(vn) have some `missing` and some not; this is currently not supported") + end + + return false +end # If we're working with, say, a `Symbol`, then we're not going to `view`. maybe_view(x) = x @@ -317,12 +385,11 @@ function generate_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn inds return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) - if $isassumption + if $(isassumption(:__model__, left, vn)) $left = $(DynamicPPL.tilde_assume!)( __context__, $(DynamicPPL.unwrap_right_vn)( @@ -361,12 +428,11 @@ function generate_dot_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn inds return quote $vn = $(varname(left)) $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) - if $isassumption + if $(isassumption(:__model__, left, vn)) $left .= $(DynamicPPL.dot_tilde_assume!)( __context__, $(DynamicPPL.unwrap_right_left_vns)( diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..00b853cf0 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -231,7 +231,7 @@ function assume(dist::Distribution, vn::VarName, vi) error("variable $vn does not exist") end r = vi[vn] - return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn)) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end # SampleFromPrior and SampleFromUniform