Skip to content
18 changes: 14 additions & 4 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function isassumption(expr::Union{Symbol,Expr})
true
else
# Evaluate the LHS
$expr === missing
$(maybe_view(expr)) === missing
end
end
end
Expand All @@ -36,6 +36,16 @@ end
# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

# If we're working with, say, a `Symbol`, then we're not going to `view`.
maybe_view(x) = x
maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@view($x)))

# If the result of a `view` is a zero-dim array then it's just a
# single element. Likely the rest is expecting type `eltype(x)`, hence
# we extract the value rather than passing the array.
maybe_unwrap_view(x) = x
maybe_unwrap_view(x::SubArray{<:Any,0}) = x[1]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment on why we need unwrap view here? Otherwise, happy to merge as-is.


"""
isliteral(expr)

Expand Down Expand Up @@ -325,7 +335,7 @@ function generate_tilde(left, right)
$(DynamicPPL.tilde_observe!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$(maybe_view(left)),
$vn,
$inds,
__varinfo__,
Expand Down Expand Up @@ -360,7 +370,7 @@ function generate_dot_tilde(left, right)
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
)...,
$inds,
__varinfo__,
Expand All @@ -369,7 +379,7 @@ function generate_dot_tilde(left, right)
$(DynamicPPL.dot_tilde_observe!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$(maybe_view(left)),
$vn,
$inds,
__varinfo__,
Expand Down