Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 188 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,44 @@ function generate_mainbody!(found, expr::Expr, args, warn)
return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args, warn)
end

if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 && expr.args[1] === Symbol("@reparam")
# @reparam f x ~ dist
# but `expr.args[2]` is the line-information, so we identify `f` and `x ~ dist` by indexing from `end`.
reparam = expr.args[end - 1]

# Don't override the `expr` in main body
let expr = expr.args[end]
# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
return generate_dot_tilde_with_reparam(
generate_mainbody!(found, L, args, warn),
generate_mainbody!(found, R, args, warn),
args,
reparam
) |> Base.remove_linenums!
end

# Modify tilde operators.
args_tilde = getargs_tilde(expr)
if args_tilde !== nothing
L, R = args_tilde
return generate_tilde_with_reparam(
generate_mainbody!(found, L, args, warn),
generate_mainbody!(found, R, args, warn),
args,
reparam
) |> Base.remove_linenums!
end

return Expr(
expr.head,
map(x -> generate_mainbody!(found, x, args, warn), expr.args)...
)
end
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
Expand All @@ -205,7 +243,6 @@ function generate_mainbody!(found, expr::Expr, args, warn)
end



"""
generate_tilde(left, right, args)

Expand Down Expand Up @@ -297,6 +334,156 @@ function generate_dot_tilde(left, right, args)
end
end

"""
generate_tilde_with_reparam(left, right, args, reparam)

Generate an `observe` expression for data variables and `assume` expression for parameter
variables with reparameterization.
"""
function generate_tilde_with_reparam(left, right, args, reparam)
@gensym tmpright
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
|| throw(ArgumentError($DISTMSG)))]

if left isa Symbol || left isa Expr
@gensym out vn inds left_intermediate
push!(top, :($vn = $(DynamicPPL.varname2intermediate)($(varname(left)))))
push!(top, :($inds = $(vinds(left))))

# `reparam` might be a `Bijectors.AbstractBijector` which we only really want to
# compute once per sampling statement, and so we "cache" the constructor.
@gensym f
push!(top, :($f = $reparam))

# It can only be an observation if the LHS is an argument of the model
if DynamicPPL.vsym(left) in args
@gensym isassumption
return quote
$(top...)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left = begin
$left_intermediate = $(DynamicPPL.tilde_assume)(
_rng, _context, _sampler,
$tmpright, $vn, $inds, _varinfo
)

$f($left_intermediate)
end
else
if $f isa $(Bijectors.AbstractBijector)
$left_intermediate = inv($f)($left)
$(DynamicPPL.tilde_observe)(
_context, _sampler,
$tmpright, $left_intermediate, $vn, $inds, _varinfo
)
else
throw(ArgumentError("cannot observe non-invertible reparameterization!!!"))
end
end
end
end

return quote
$(top...)
$left = begin
$left_intermediate = $(DynamicPPL.tilde_assume)(
_rng, _context, _sampler, $tmpright, $vn,
$inds, _varinfo
)

$f($left_intermediate)
end
end
end

# If the LHS is a literal, it is always an observation
return quote
$(top...)
if $f isa $(Bijectors.AbstractBijector)
$left_intermediate = inv($f)($left)
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left_intermediate, _varinfo)
else
throw(ArgumentError("cannot observe non-invertible reparameterization!!!"))
end
end
end

"""
generate_dot_tilde_with_reparam(left, right, args, reparam)

Generate the expression that replaces `@reparam f left .~ right` in the model body.
"""
function generate_dot_tilde_with_reparam(left, right, args, reparam)
@gensym tmpright
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
|| throw(ArgumentError($DISTMSG)))]

if left isa Symbol || left isa Expr
@gensym out vn inds left_intermediate
push!(top, :($vn = $(DynamicPPL.varname2intermediate)($(varname(left)))))
push!(top, :($inds = $(vinds(left))))

# `reparam` might be a `Bijectors.AbstractBijector` which we only really want to
# compute once per sampling statement, and so we "cache" the constructor.
@gensym f
push!(top, :($f = $reparam))

# It can only be an observation if the LHS is an argument of the model
if vsym(left) in args
@gensym isassumption
return quote
$(top...)
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$left .= begin
$left_intermediate = $(DynamicPPL.dot_tilde_assume)(
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo
)

$f.($left_intermediate)
end
else
if $f isa $(Bijectors.AbstractBijector)
$left_intermediate = inv($f).($left)
$(DynamicPPL.dot_tilde_observe)(
_context, _sampler, $tmpright, $left_intermediate, $vn, $inds, _varinfo
)
else
throw(ArgumentError("cannot observe non-invertible reparameterization!!!"))
end
end
end
end

return quote
$(top...)
$left .= begin
$left_intermediate = $(DynamicPPL.dot_tilde_assume)(
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo
)

$f.($left_intermediate)
end

end
end

# If the LHS is a literal, it is always an observation
return quote
$(top...)
if $f isa $(Bijectors.AbstractBijector)
$left_intermediate = inv($f).($left)
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left_intermediate, _varinfo)
else
throw(ArgumentError("cannot observe non-invertible reparameterization!!!"))
end
end
end


const FloatOrArrayType = Type{<:Union{AbstractFloat, AbstractArray}}
hasmissing(T::Type{<:AbstractArray{TA}}) where {TA <: AbstractArray} = hasmissing(TA)
hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true
Expand Down
11 changes: 11 additions & 0 deletions src/varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,14 @@ end
@generated function inmissings(::VarName{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T}
return s in missings
end

"""
varname2intermediate(vn::VarName)

Converts `vn` into a `VarName` representing a "intermediate" variable used to obtain `vn`.

Essentially just adds an underscore to the name, e.g. `x[:, 1]` becomes `x_[:, 1]`.
"""
function varname2intermediate(vn::VarName{sym}) where {sym}
DynamicPPL.VarName{Symbol(sym, "_"), typeof(vn.indexing)}(vn.indexing)
end