Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.11.4"
version = "0.12.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
9 changes: 6 additions & 3 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,20 @@ export AbstractVarInfo,
SampleFromPrior,
SampleFromUniform,
# Contexts
SamplingContext,
DefaultContext,
LikelihoodContext,
PriorContext,
MiniBatchContext,
PrefixContext,
assume,
dot_assume,
observer,
observe,
dot_observe,
tilde,
dot_tilde,
tilde_assume,
tilde_observe,
dot_tilde_assume,
dot_tilde_observe,
# Pseudo distributions
NamedDist,
NoDist,
Expand Down
103 changes: 72 additions & 31 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,49 @@ end
check_tilde_rhs(x::Distribution) = x
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x

"""
unwrap_right_vn(right, vn)

Return the unwrapped distribution on the right-hand side and variable name on the left-hand
side of a `~` expression such as `x ~ Normal()`.

This is used mainly to unwrap `NamedDist` distributions.
"""
unwrap_right_vn(right, vn) = right, vn
unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name)

"""
unwrap_right_left_vns(right, left, vns)

Return the unwrapped distributions on the right-hand side and values and variable names on the
left-hand side of a `.~` expression such as `x .~ Normal()`.

This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the
variables.
"""
unwrap_right_left_vns(right, left, vns) = right, left, vns
function unwrap_right_left_vns(right::NamedDist, left, vns)
return unwrap_right_left_vns(right.dist, left, right.name)
end
function unwrap_right_left_vns(
right::MultivariateDistribution, left::AbstractMatrix, vn::VarName
)
vns = map(axes(left, 2)) do i
return VarName(vn, (vn.indexing..., Tuple(i)))
Copy link
Member

@yebai yebai Jun 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slightly confused here - is the argument Tuple(i) redundant with vn.indexing?

EDIT: Got the answer after reading the previous code. Maybe consider adding a comment here?

end
return unwrap_right_left_vns(right, left, vns)
end
function unwrap_right_left_vns(
right::Union{Distribution,AbstractArray{<:Distribution}},
left::AbstractArray,
vn::VarName,
)
vns = map(CartesianIndices(left)) do i
return VarName(vn, (vn.indexing..., Tuple(i)))
end
return unwrap_right_left_vns(right, left, vns)
end

#################
# Main Compiler #
#################
Expand Down Expand Up @@ -256,12 +299,8 @@ function generate_tilde(left, right)
# If the LHS is a literal, it is always an observation
if isliteral(left)
return quote
$(DynamicPPL.tilde_observe)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
__varinfo__,
$(DynamicPPL.tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
)
end
end
Expand All @@ -274,19 +313,17 @@ function generate_tilde(left, right)
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = $(DynamicPPL.tilde_assume)(
__rng__,
$left = $(DynamicPPL.tilde_assume!)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$vn,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
$inds,
__varinfo__,
)
else
$(DynamicPPL.tilde_observe)(
$(DynamicPPL.tilde_observe!)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
Expand All @@ -306,12 +343,8 @@ function generate_dot_tilde(left, right)
# If the LHS is a literal, it is always an observation
if isliteral(left)
return quote
$(DynamicPPL.dot_tilde_observe)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
__varinfo__,
$(DynamicPPL.dot_tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
)
end
end
Expand All @@ -324,20 +357,17 @@ function generate_dot_tilde(left, right)
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume)(
__rng__,
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
)...,
$inds,
__varinfo__,
)
else
$(DynamicPPL.dot_tilde_observe)(
$(DynamicPPL.dot_tilde_observe!)(
__context__,
__sampler__,
$(DynamicPPL.check_tilde_rhs)($right),
$left,
$vn,
Expand Down Expand Up @@ -368,10 +398,8 @@ function build_output(modelinfo, linenumbernode)
# Add the internal arguments to the user-specified arguments (positional + keywords).
evaluatordef[:args] = vcat(
[
:(__rng__::$(Random.AbstractRNG)),
:(__model__::$(DynamicPPL.Model)),
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
:(__sampler__::$(DynamicPPL.AbstractSampler)),
:(__context__::$(DynamicPPL.AbstractContext)),
],
modelinfo[:allargs_exprs],
Expand Down Expand Up @@ -419,8 +447,12 @@ end

"""
matchingvalue(sampler, vi, value)
matchingvalue(context::AbstractContext, vi, value)

Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object.

Convert the `value` to the correct type for the `sampler` and the `vi` object.
For a `context` that is _not_ a `SamplingContext`, we fall back to
`matchingvalue(SampleFromPrior(), vi, value)`.
"""
function matchingvalue(sampler, vi, value)
T = typeof(value)
Expand All @@ -435,7 +467,16 @@ function matchingvalue(sampler, vi, value)
return value
end
end
matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value)
function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType)
return get_matching_type(sampler, vi, value)
end

function matchingvalue(context::AbstractContext, vi, value)
return matchingvalue(SampleFromPrior(), vi, value)
end
function matchingvalue(context::SamplingContext, vi, value)
return matchingvalue(context.sampler, vi, value)
end

"""
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T}
Expand Down
Loading