diff --git a/src/compiler.jl b/src/compiler.jl index 73ae707c5..af468e61d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -183,6 +183,44 @@ function generate_mainbody!(found, expr::Expr, args, warn) return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args, warn) end + if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 && expr.args[1] === Symbol("@reparam") + # @reparam f x ~ dist + # but `expr.args[2]` is the line-information, so we identify `f` and `x ~ dist` by indexing from `end`. + reparam = expr.args[end - 1] + + # Don't override the `expr` in main body + let expr = expr.args[end] + # Modify dotted tilde operators. + args_dottilde = getargs_dottilde(expr) + if args_dottilde !== nothing + L, R = args_dottilde + return generate_dot_tilde_with_reparam( + generate_mainbody!(found, L, args, warn), + generate_mainbody!(found, R, args, warn), + args, + reparam + ) |> Base.remove_linenums! + end + + # Modify tilde operators. + args_tilde = getargs_tilde(expr) + if args_tilde !== nothing + L, R = args_tilde + return generate_tilde_with_reparam( + generate_mainbody!(found, L, args, warn), + generate_mainbody!(found, R, args, warn), + args, + reparam + ) |> Base.remove_linenums! + end + + return Expr( + expr.head, + map(x -> generate_mainbody!(found, x, args, warn), expr.args)... + ) + end + end + # Modify dotted tilde operators. args_dottilde = getargs_dottilde(expr) if args_dottilde !== nothing @@ -205,7 +243,6 @@ function generate_mainbody!(found, expr::Expr, args, warn) end - """ generate_tilde(left, right, args) @@ -297,6 +334,156 @@ function generate_dot_tilde(left, right, args) end end +""" + generate_tilde_with_reparam(left, right, args, reparam) + +Generate an `observe` expression for data variables and `assume` expression for parameter +variables with reparameterization. +""" +function generate_tilde_with_reparam(left, right, args, reparam) + @gensym tmpright + top = [:($tmpright = $right), + :($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}} + || throw(ArgumentError($DISTMSG)))] + + if left isa Symbol || left isa Expr + @gensym out vn inds left_intermediate + push!(top, :($vn = $(DynamicPPL.varname2intermediate)($(varname(left))))) + push!(top, :($inds = $(vinds(left)))) + + # `reparam` might be a `Bijectors.AbstractBijector` which we only really want to + # compute once per sampling statement, and so we "cache" the constructor. + @gensym f + push!(top, :($f = $reparam)) + + # It can only be an observation if the LHS is an argument of the model + if DynamicPPL.vsym(left) in args + @gensym isassumption + return quote + $(top...) + $isassumption = $(DynamicPPL.isassumption(left)) + if $isassumption + $left = begin + $left_intermediate = $(DynamicPPL.tilde_assume)( + _rng, _context, _sampler, + $tmpright, $vn, $inds, _varinfo + ) + + $f($left_intermediate) + end + else + if $f isa $(Bijectors.AbstractBijector) + $left_intermediate = inv($f)($left) + $(DynamicPPL.tilde_observe)( + _context, _sampler, + $tmpright, $left_intermediate, $vn, $inds, _varinfo + ) + else + throw(ArgumentError("cannot observe non-invertible reparameterization!!!")) + end + end + end + end + + return quote + $(top...) + $left = begin + $left_intermediate = $(DynamicPPL.tilde_assume)( + _rng, _context, _sampler, $tmpright, $vn, + $inds, _varinfo + ) + + $f($left_intermediate) + end + end + end + + # If the LHS is a literal, it is always an observation + return quote + $(top...) + if $f isa $(Bijectors.AbstractBijector) + $left_intermediate = inv($f)($left) + $(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left_intermediate, _varinfo) + else + throw(ArgumentError("cannot observe non-invertible reparameterization!!!")) + end + end +end + +""" + generate_dot_tilde_with_reparam(left, right, args, reparam) + +Generate the expression that replaces `@reparam f left .~ right` in the model body. +""" +function generate_dot_tilde_with_reparam(left, right, args, reparam) + @gensym tmpright + top = [:($tmpright = $right), + :($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}} + || throw(ArgumentError($DISTMSG)))] + + if left isa Symbol || left isa Expr + @gensym out vn inds left_intermediate + push!(top, :($vn = $(DynamicPPL.varname2intermediate)($(varname(left))))) + push!(top, :($inds = $(vinds(left)))) + + # `reparam` might be a `Bijectors.AbstractBijector` which we only really want to + # compute once per sampling statement, and so we "cache" the constructor. + @gensym f + push!(top, :($f = $reparam)) + + # It can only be an observation if the LHS is an argument of the model + if vsym(left) in args + @gensym isassumption + return quote + $(top...) + $isassumption = $(DynamicPPL.isassumption(left)) + if $isassumption + $left .= begin + $left_intermediate = $(DynamicPPL.dot_tilde_assume)( + _rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo + ) + + $f.($left_intermediate) + end + else + if $f isa $(Bijectors.AbstractBijector) + $left_intermediate = inv($f).($left) + $(DynamicPPL.dot_tilde_observe)( + _context, _sampler, $tmpright, $left_intermediate, $vn, $inds, _varinfo + ) + else + throw(ArgumentError("cannot observe non-invertible reparameterization!!!")) + end + end + end + end + + return quote + $(top...) + $left .= begin + $left_intermediate = $(DynamicPPL.dot_tilde_assume)( + _rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo + ) + + $f.($left_intermediate) + end + + end + end + + # If the LHS is a literal, it is always an observation + return quote + $(top...) + if $f isa $(Bijectors.AbstractBijector) + $left_intermediate = inv($f).($left) + $(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left_intermediate, _varinfo) + else + throw(ArgumentError("cannot observe non-invertible reparameterization!!!")) + end + end +end + + const FloatOrArrayType = Type{<:Union{AbstractFloat, AbstractArray}} hasmissing(T::Type{<:AbstractArray{TA}}) where {TA <: AbstractArray} = hasmissing(TA) hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true diff --git a/src/varname.jl b/src/varname.jl index ed58e4754..ceb63ceee 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -238,3 +238,14 @@ end @generated function inmissings(::VarName{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T} return s in missings end + +""" + varname2intermediate(vn::VarName) + +Converts `vn` into a `VarName` representing a "intermediate" variable used to obtain `vn`. + +Essentially just adds an underscore to the name, e.g. `x[:, 1]` becomes `x_[:, 1]`. +""" +function varname2intermediate(vn::VarName{sym}) where {sym} + DynamicPPL.VarName{Symbol(sym, "_"), typeof(vn.indexing)}(vn.indexing) +end