diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..91fe78e2b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -27,7 +27,7 @@ function isassumption(expr::Union{Symbol,Expr}) true else # Evaluate the LHS - $expr === missing + $(maybe_view(expr)) === missing end end end @@ -36,6 +36,16 @@ end # failsafe: a literal is never an assumption isassumption(expr) = :(false) +# If we're working with, say, a `Symbol`, then we're not going to `view`. +maybe_view(x) = x +maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@view($x))) + +# If the result of a `view` is a zero-dim array then it's just a +# single element. Likely the rest is expecting type `eltype(x)`, hence +# we extract the value rather than passing the array. +maybe_unwrap_view(x) = x +maybe_unwrap_view(x::SubArray{<:Any,0}) = x[1] + """ isliteral(expr) @@ -325,7 +335,7 @@ function generate_tilde(left, right) $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), - $left, + $(maybe_view(left)), $vn, $inds, __varinfo__, @@ -360,7 +370,7 @@ function generate_dot_tilde(left, right) $left .= $(DynamicPPL.dot_tilde_assume!)( __context__, $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $left, $vn + $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn )..., $inds, __varinfo__, @@ -369,7 +379,7 @@ function generate_dot_tilde(left, right) $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), - $left, + $(maybe_view(left)), $vn, $inds, __varinfo__,